feat: Portal, Email Inbound, Discuss + module improvements
- Portal: /my/* routes, signup, password reset, portal user support - Email Inbound: IMAP polling (go-imap/v2), thread matching - Discuss: mail.channel, long-polling bus, DM, unread count - Cron: ir.cron runner (goroutine scheduler) - Bank Import, CSV/Excel Import - Automation (ir.actions.server) - Fetchmail service - HR Payroll model - Various fixes across account, sale, stock, purchase, crm, hr, project Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
package orm
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
// ComputeFunc is a function that computes field values for a recordset.
|
||||
// Mirrors: @api.depends decorated methods in Odoo.
|
||||
@@ -253,7 +256,7 @@ func RunOnchangeComputes(m *Model, env *Environment, currentVals Values, changed
|
||||
|
||||
computed, err := fn(rs)
|
||||
if err != nil {
|
||||
// Non-fatal: skip failed computes during onchange
|
||||
log.Printf("orm: onchange compute %s.%s failed: %v", m.Name(), fieldName, err)
|
||||
continue
|
||||
}
|
||||
for k, v := range computed {
|
||||
|
||||
@@ -2,6 +2,8 @@ package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -152,6 +154,8 @@ func (dc *DomainCompiler) JoinSQL() string {
|
||||
return " " + strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// compileNodes compiles domain nodes in Polish (prefix) notation.
|
||||
// Returns the SQL string and the number of nodes consumed from the domain starting at pos.
|
||||
func (dc *DomainCompiler) compileNodes(domain Domain, pos int) (string, error) {
|
||||
if pos >= len(domain) {
|
||||
return "TRUE", nil
|
||||
@@ -167,7 +171,8 @@ func (dc *DomainCompiler) compileNodes(domain Domain, pos int) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
right, err := dc.compileNodes(domain, pos+2)
|
||||
leftSize := nodeSize(domain, pos+1)
|
||||
right, err := dc.compileNodes(domain, pos+1+leftSize)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -178,7 +183,8 @@ func (dc *DomainCompiler) compileNodes(domain Domain, pos int) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
right, err := dc.compileNodes(domain, pos+2)
|
||||
leftSize := nodeSize(domain, pos+1)
|
||||
right, err := dc.compileNodes(domain, pos+1+leftSize)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -196,8 +202,6 @@ func (dc *DomainCompiler) compileNodes(domain Domain, pos int) (string, error) {
|
||||
return dc.compileCondition(n)
|
||||
|
||||
case domainGroup:
|
||||
// domainGroup wraps a sub-domain as a single node.
|
||||
// Compile it recursively as a full domain.
|
||||
subSQL, _, err := dc.compileDomainGroup(Domain(n))
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -208,6 +212,28 @@ func (dc *DomainCompiler) compileNodes(domain Domain, pos int) (string, error) {
|
||||
return "", fmt.Errorf("unexpected domain node at position %d: %v", pos, node)
|
||||
}
|
||||
|
||||
// nodeSize returns the number of domain nodes consumed by the subtree at pos.
|
||||
// Operators (&, |) consume 1 + left subtree + right subtree.
|
||||
// NOT consumes 1 + inner subtree. Leaf nodes consume 1.
|
||||
func nodeSize(domain Domain, pos int) int {
|
||||
if pos >= len(domain) {
|
||||
return 0
|
||||
}
|
||||
switch n := domain[pos].(type) {
|
||||
case Operator:
|
||||
_ = n
|
||||
switch domain[pos].(Operator) {
|
||||
case OpAnd, OpOr:
|
||||
leftSize := nodeSize(domain, pos+1)
|
||||
rightSize := nodeSize(domain, pos+1+leftSize)
|
||||
return 1 + leftSize + rightSize
|
||||
case OpNot:
|
||||
return 1 + nodeSize(domain, pos+1)
|
||||
}
|
||||
}
|
||||
return 1 // Condition or domainGroup = 1 node
|
||||
}
|
||||
|
||||
// compileDomainGroup compiles a sub-domain that was wrapped via domainGroup.
|
||||
// It reuses the same DomainCompiler (sharing params and joins) so parameter
|
||||
// indices stay consistent with the outer query.
|
||||
@@ -227,14 +253,12 @@ func (dc *DomainCompiler) compileCondition(c Condition) (string, error) {
|
||||
return "", fmt.Errorf("invalid operator: %q", c.Operator)
|
||||
}
|
||||
|
||||
// Handle dot notation (e.g., "partner_id.name")
|
||||
// Handle dot notation (e.g., "partner_id.name", "partner_id.country_id.code")
|
||||
// by generating LEFT JOINs through the M2O relational chain.
|
||||
parts := strings.Split(c.Field, ".")
|
||||
column := parts[0]
|
||||
|
||||
// TODO: Handle JOINs for dot notation paths
|
||||
// For now, only support direct fields
|
||||
if len(parts) > 1 {
|
||||
// Placeholder for JOIN resolution
|
||||
return dc.compileJoinedCondition(parts, c.Operator, c.Value)
|
||||
}
|
||||
|
||||
@@ -285,7 +309,7 @@ func (dc *DomainCompiler) compileJoinedCondition(fieldPath []string, operator st
|
||||
dc.joins = append(dc.joins, joinClause{
|
||||
table: comodel.Table(),
|
||||
alias: alias,
|
||||
on: fmt.Sprintf("%s.%q = %q.\"id\"", currentAlias, f.Column(), alias),
|
||||
on: fmt.Sprintf("%q.%q = %q.\"id\"", currentAlias, f.Column(), alias),
|
||||
})
|
||||
|
||||
currentModel = comodel
|
||||
@@ -293,8 +317,12 @@ func (dc *DomainCompiler) compileJoinedCondition(fieldPath []string, operator st
|
||||
}
|
||||
|
||||
// The last segment is the actual field to filter on
|
||||
leafField := fieldPath[len(fieldPath)-1]
|
||||
qualifiedColumn := fmt.Sprintf("%s.%q", currentAlias, leafField)
|
||||
leafFieldName := fieldPath[len(fieldPath)-1]
|
||||
leafCol := leafFieldName
|
||||
if lf := currentModel.GetField(leafFieldName); lf != nil {
|
||||
leafCol = lf.Column()
|
||||
}
|
||||
qualifiedColumn := fmt.Sprintf("%q.%q", currentAlias, leafCol)
|
||||
|
||||
return dc.compileQualifiedCondition(qualifiedColumn, operator, value)
|
||||
}
|
||||
@@ -528,13 +556,8 @@ func (dc *DomainCompiler) compileAnyOp(column string, value Value, negate bool)
|
||||
// Rebase parameter indices: shift them by the current param count
|
||||
baseIdx := len(dc.params)
|
||||
dc.params = append(dc.params, subParams...)
|
||||
rebased := subWhere
|
||||
// Replace $N with $(N+baseIdx) in the sub-where clause
|
||||
for i := len(subParams); i >= 1; i-- {
|
||||
old := fmt.Sprintf("$%d", i)
|
||||
new := fmt.Sprintf("$%d", i+baseIdx)
|
||||
rebased = strings.ReplaceAll(rebased, old, new)
|
||||
}
|
||||
// Replace $N with $(N+baseIdx) using regex to avoid $1 matching $10
|
||||
rebased := rebaseParams(subWhere, baseIdx)
|
||||
|
||||
// Determine the join condition based on field type
|
||||
var joinCond string
|
||||
@@ -676,3 +699,14 @@ func wrapLikeValue(value Value) Value {
|
||||
}
|
||||
return "%" + s + "%"
|
||||
}
|
||||
|
||||
// rebaseParams shifts $N placeholders in a SQL string by baseIdx.
|
||||
// Uses regex to avoid $1 matching inside $10.
|
||||
var paramRegex = regexp.MustCompile(`\$(\d+)`)
|
||||
|
||||
func rebaseParams(sql string, baseIdx int) string {
|
||||
return paramRegex.ReplaceAllStringFunc(sql, func(match string) string {
|
||||
n, _ := strconv.Atoi(match[1:])
|
||||
return fmt.Sprintf("$%d", n+baseIdx)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -45,8 +45,9 @@ type Model struct {
|
||||
checkCompany bool // Enforce multi-company record rules
|
||||
|
||||
// Hooks
|
||||
BeforeCreate func(env *Environment, vals Values) error // Called before INSERT
|
||||
DefaultGet func(env *Environment, fields []string) Values // Dynamic defaults (e.g., from DB)
|
||||
BeforeCreate func(env *Environment, vals Values) error // Called before INSERT
|
||||
BeforeWrite func(env *Environment, ids []int64, vals Values) error // Called before UPDATE — for state guards
|
||||
DefaultGet func(env *Environment, fields []string) Values // Dynamic defaults (e.g., from DB)
|
||||
Constraints []ConstraintFunc // Validation constraints
|
||||
Methods map[string]MethodFunc // Named business methods
|
||||
|
||||
@@ -453,3 +454,32 @@ func (m *Model) Many2manyTableSQL() []string {
|
||||
}
|
||||
return stmts
|
||||
}
|
||||
|
||||
// StateGuard returns a BeforeWrite function that prevents modifications on records
|
||||
// in certain states, except for explicitly allowed fields.
|
||||
// Eliminates the duplicated guard pattern across sale.order, purchase.order,
|
||||
// account.move, and stock.picking.
|
||||
func StateGuard(table, stateCondition string, allowedFields []string, errMsg string) func(env *Environment, ids []int64, vals Values) error {
|
||||
allowed := make(map[string]bool, len(allowedFields))
|
||||
for _, f := range allowedFields {
|
||||
allowed[f] = true
|
||||
}
|
||||
return func(env *Environment, ids []int64, vals Values) error {
|
||||
if _, changingState := vals["state"]; changingState {
|
||||
return nil
|
||||
}
|
||||
var count int
|
||||
err := env.Tx().QueryRow(env.Ctx(),
|
||||
fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE id = ANY($1) AND %s`, table, stateCondition), ids,
|
||||
).Scan(&count)
|
||||
if err != nil || count == 0 {
|
||||
return nil
|
||||
}
|
||||
for field := range vals {
|
||||
if !allowed[field] {
|
||||
return fmt.Errorf("%s: %s", table, errMsg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,7 +153,7 @@ func (rs *Recordset) ReadGroup(domain Domain, groupby []string, aggregates []str
|
||||
// Build ORDER BY
|
||||
orderSQL := ""
|
||||
if opt.Order != "" {
|
||||
orderSQL = opt.Order
|
||||
orderSQL = sanitizeOrderBy(opt.Order, m)
|
||||
} else if len(gbCols) > 0 {
|
||||
// Default: order by groupby columns
|
||||
var orderParts []string
|
||||
|
||||
@@ -2,6 +2,7 @@ package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -265,18 +266,28 @@ func preprocessRelatedWrites(env *Environment, m *Model, ids []int64, vals Value
|
||||
value := vals[fieldName]
|
||||
delete(vals, fieldName) // Remove from vals — no local column
|
||||
|
||||
// Read FK IDs for all records
|
||||
// Read FK IDs for all records in a single query
|
||||
var fkIDs []int64
|
||||
for _, id := range ids {
|
||||
var fkID *int64
|
||||
env.tx.QueryRow(env.ctx,
|
||||
fmt.Sprintf(`SELECT %q FROM %q WHERE id = $1`, fkDef.Column(), m.Table()),
|
||||
id,
|
||||
).Scan(&fkID)
|
||||
if fkID != nil && *fkID > 0 {
|
||||
fkIDs = append(fkIDs, *fkID)
|
||||
rows, err := env.tx.Query(env.ctx,
|
||||
fmt.Sprintf(`SELECT %q FROM %q WHERE id = ANY($1) AND %q IS NOT NULL`,
|
||||
fkDef.Column(), m.Table(), fkDef.Column()),
|
||||
ids,
|
||||
)
|
||||
if err != nil {
|
||||
delete(vals, fieldName)
|
||||
continue
|
||||
}
|
||||
for rows.Next() {
|
||||
var fkID int64
|
||||
if err := rows.Scan(&fkID); err != nil {
|
||||
log.Printf("orm: preprocessRelatedWrites scan error on %s.%s: %v", m.Name(), fieldName, err)
|
||||
continue
|
||||
}
|
||||
if fkID > 0 {
|
||||
fkIDs = append(fkIDs, fkID)
|
||||
}
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
if len(fkIDs) == 0 {
|
||||
continue
|
||||
@@ -315,6 +326,13 @@ func (rs *Recordset) Write(vals Values) error {
|
||||
|
||||
m := rs.model
|
||||
|
||||
// BeforeWrite hook — state guards, locked record checks etc.
|
||||
if m.BeforeWrite != nil {
|
||||
if err := m.BeforeWrite(rs.env, rs.ids, vals); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var setClauses []string
|
||||
var args []interface{}
|
||||
idx := 1
|
||||
@@ -787,7 +805,7 @@ func (rs *Recordset) Search(domain Domain, opts ...SearchOpts) (*Recordset, erro
|
||||
// Build query
|
||||
order := m.order
|
||||
if opt.Order != "" {
|
||||
order = opt.Order
|
||||
order = sanitizeOrderBy(opt.Order, m)
|
||||
}
|
||||
|
||||
joinSQL := compiler.JoinSQL()
|
||||
@@ -1103,6 +1121,72 @@ func toRecordID(v interface{}) (int64, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// sanitizeOrderBy validates an ORDER BY clause to prevent SQL injection.
|
||||
// Only allows: field names (alphanumeric + underscore), ASC/DESC, NULLS FIRST/LAST, commas.
|
||||
// Returns sanitized string or fallback to "id" if invalid.
|
||||
func sanitizeOrderBy(order string, m *Model) string {
|
||||
if order == "" {
|
||||
return "id"
|
||||
}
|
||||
parts := strings.Split(order, ",")
|
||||
var safe []string
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
tokens := strings.Fields(part)
|
||||
if len(tokens) == 0 {
|
||||
continue
|
||||
}
|
||||
// First token must be a valid field name or "table"."field"
|
||||
col := tokens[0]
|
||||
// Strip quotes for validation
|
||||
cleanCol := strings.ReplaceAll(strings.ReplaceAll(col, "\"", ""), "'", "")
|
||||
// Allow dot notation (table.field) but validate each part
|
||||
colParts := strings.Split(cleanCol, ".")
|
||||
valid := true
|
||||
for _, cp := range colParts {
|
||||
if !isValidIdentifier(cp) {
|
||||
valid = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
continue // Skip this part entirely
|
||||
}
|
||||
// Remaining tokens must be ASC, DESC, NULLS, FIRST, LAST
|
||||
safePart := col
|
||||
for _, tok := range tokens[1:] {
|
||||
upper := strings.ToUpper(tok)
|
||||
switch upper {
|
||||
case "ASC", "DESC", "NULLS", "FIRST", "LAST":
|
||||
safePart += " " + upper
|
||||
default:
|
||||
// Invalid token — skip
|
||||
}
|
||||
}
|
||||
safe = append(safe, safePart)
|
||||
}
|
||||
if len(safe) == 0 {
|
||||
return "id"
|
||||
}
|
||||
return strings.Join(safe, ", ")
|
||||
}
|
||||
|
||||
// isValidIdentifier checks if a string is a valid SQL identifier (letters, digits, underscore).
|
||||
func isValidIdentifier(s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
for _, c := range s {
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// qualifyOrderBy prefixes unqualified column names with the table name.
|
||||
// "name, id desc" → "\"my_table\".name, \"my_table\".id desc"
|
||||
func qualifyOrderBy(table, order string) string {
|
||||
|
||||
@@ -70,8 +70,9 @@ func ApplyRecordRules(env *Environment, m *Model, domain Domain) Domain {
|
||||
ORDER BY r.id`,
|
||||
m.Name(), env.UID())
|
||||
if err != nil {
|
||||
log.Printf("orm: ir.rule query failed for %s: %v — denying access", m.Name(), err)
|
||||
sp.Rollback(env.ctx)
|
||||
return domain
|
||||
return append(domain, Leaf("id", "=", -1)) // Deny all — no records match id=-1
|
||||
}
|
||||
|
||||
type ruleRow struct {
|
||||
@@ -207,7 +208,8 @@ func CheckRecordRuleAccess(env *Environment, m *Model, ids []int64, perm string)
|
||||
var count int64
|
||||
err := env.tx.QueryRow(env.ctx, query, args...).Scan(&count)
|
||||
if err != nil {
|
||||
return nil // Fail open on error
|
||||
log.Printf("orm: record rule check failed for %s: %v", m.Name(), err)
|
||||
return fmt.Errorf("orm: access denied on %s (record rule check failed)", m.Name())
|
||||
}
|
||||
|
||||
if count < int64(len(ids)) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
@@ -145,10 +146,12 @@ func (s *Server) handleActionLoad(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Look up xml_id from ir_model_data
|
||||
xmlID := ""
|
||||
_ = s.pool.QueryRow(ctx,
|
||||
if err := s.pool.QueryRow(ctx,
|
||||
`SELECT module || '.' || name FROM ir_model_data
|
||||
WHERE model = 'ir.actions.act_window' AND res_id = $1
|
||||
LIMIT 1`, id).Scan(&xmlID)
|
||||
LIMIT 1`, id).Scan(&xmlID); err != nil {
|
||||
log.Printf("warning: action xml_id lookup failed for id=%d: %v", id, err)
|
||||
}
|
||||
|
||||
// Build views array from view_mode string (e.g. "list,kanban,form" → [[nil,"list"],[nil,"kanban"],[nil,"form"]])
|
||||
views := buildViewsFromMode(viewMode)
|
||||
|
||||
292
pkg/server/bank_import.go
Normal file
292
pkg/server/bank_import.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"odoo-go/pkg/orm"
|
||||
)
|
||||
|
||||
// handleBankStatementImport imports bank statement lines from CSV data.
|
||||
// Accepts JSON body with: journal_id, csv_data, column_mapping, has_header.
|
||||
// After import, optionally triggers auto-matching against open invoices.
|
||||
// Mirrors: odoo/addons/account/wizard/account_bank_statement_import.py
|
||||
func (s *Server) handleBankStatementImport(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req JSONRPCRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.writeJSONRPC(w, nil, nil, &RPCError{Code: -32700, Message: "Parse error"})
|
||||
return
|
||||
}
|
||||
|
||||
var params struct {
|
||||
JournalID int64 `json:"journal_id"`
|
||||
CSVData string `json:"csv_data"`
|
||||
HasHeader bool `json:"has_header"`
|
||||
ColumnMapping bankColumnMapping `json:"column_mapping"`
|
||||
AutoMatch bool `json:"auto_match"`
|
||||
}
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "Invalid params"})
|
||||
return
|
||||
}
|
||||
|
||||
if params.JournalID == 0 || params.CSVData == "" {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "journal_id and csv_data are required"})
|
||||
return
|
||||
}
|
||||
|
||||
uid := int64(1)
|
||||
companyID := int64(1)
|
||||
if sess := GetSession(r); sess != nil {
|
||||
uid = sess.UID
|
||||
companyID = sess.CompanyID
|
||||
}
|
||||
|
||||
env, err := orm.NewEnvironment(r.Context(), orm.EnvConfig{
|
||||
Pool: s.pool,
|
||||
UID: uid,
|
||||
CompanyID: companyID,
|
||||
})
|
||||
if err != nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32603, Message: "Internal error"})
|
||||
return
|
||||
}
|
||||
defer env.Close()
|
||||
|
||||
// Parse CSV
|
||||
reader := csv.NewReader(strings.NewReader(params.CSVData))
|
||||
reader.LazyQuotes = true
|
||||
reader.TrimLeadingSpace = true
|
||||
// Try semicolon separator (common in European bank exports)
|
||||
reader.Comma = detectDelimiter(params.CSVData)
|
||||
|
||||
var allRows [][]string
|
||||
for {
|
||||
row, err := reader.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: fmt.Sprintf("CSV parse error: %v", err)})
|
||||
return
|
||||
}
|
||||
allRows = append(allRows, row)
|
||||
}
|
||||
|
||||
dataRows := allRows
|
||||
if params.HasHeader && len(allRows) > 1 {
|
||||
dataRows = allRows[1:]
|
||||
}
|
||||
|
||||
// Create a bank statement header
|
||||
statementRS := env.Model("account.bank.statement")
|
||||
stmt, err := statementRS.Create(orm.Values{
|
||||
"name": fmt.Sprintf("Import %s", time.Now().Format("2006-01-02 15:04")),
|
||||
"journal_id": params.JournalID,
|
||||
"company_id": companyID,
|
||||
"date": time.Now().Format("2006-01-02"),
|
||||
})
|
||||
if err != nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32603, Message: fmt.Sprintf("Create statement: %v", err)})
|
||||
return
|
||||
}
|
||||
stmtID := stmt.ID()
|
||||
|
||||
// Default column mapping
|
||||
cm := params.ColumnMapping
|
||||
if cm.Date < 0 {
|
||||
cm.Date = 0
|
||||
}
|
||||
if cm.Amount < 0 {
|
||||
cm.Amount = 1
|
||||
}
|
||||
if cm.Label < 0 {
|
||||
cm.Label = 2
|
||||
}
|
||||
|
||||
// Import lines
|
||||
lineRS := env.Model("account.bank.statement.line")
|
||||
var importedIDs []int64
|
||||
var errors []importError
|
||||
|
||||
for rowIdx, row := range dataRows {
|
||||
// Parse date
|
||||
dateStr := safeCol(row, cm.Date)
|
||||
date := parseFlexDate(dateStr)
|
||||
if date == "" {
|
||||
date = time.Now().Format("2006-01-02")
|
||||
}
|
||||
|
||||
// Parse amount
|
||||
amountStr := safeCol(row, cm.Amount)
|
||||
amount := parseAmount(amountStr)
|
||||
if amount == 0 {
|
||||
continue // skip zero-amount rows
|
||||
}
|
||||
|
||||
// Parse label/reference
|
||||
label := safeCol(row, cm.Label)
|
||||
if label == "" {
|
||||
label = fmt.Sprintf("Line %d", rowIdx+1)
|
||||
}
|
||||
|
||||
// Parse optional columns
|
||||
partnerName := safeCol(row, cm.PartnerName)
|
||||
accountNumber := safeCol(row, cm.AccountNumber)
|
||||
|
||||
vals := orm.Values{
|
||||
"statement_id": stmtID,
|
||||
"journal_id": params.JournalID,
|
||||
"company_id": companyID,
|
||||
"date": date,
|
||||
"amount": amount,
|
||||
"payment_ref": label,
|
||||
"partner_name": partnerName,
|
||||
"account_number": accountNumber,
|
||||
"sequence": rowIdx + 1,
|
||||
}
|
||||
|
||||
rec, err := lineRS.Create(vals)
|
||||
if err != nil {
|
||||
errors = append(errors, importError{Row: rowIdx + 1, Message: err.Error()})
|
||||
log.Printf("bank_import: row %d error: %v", rowIdx+1, err)
|
||||
continue
|
||||
}
|
||||
importedIDs = append(importedIDs, rec.ID())
|
||||
}
|
||||
|
||||
// Auto-match against open invoices
|
||||
matchCount := 0
|
||||
if params.AutoMatch && len(importedIDs) > 0 {
|
||||
stLineModel := orm.Registry.Get("account.bank.statement.line")
|
||||
if stLineModel != nil {
|
||||
if matchMethod, ok := stLineModel.Methods["button_match"]; ok {
|
||||
matchRS := env.Model("account.bank.statement.line").Browse(importedIDs...)
|
||||
if _, err := matchMethod(matchRS); err != nil {
|
||||
log.Printf("bank_import: auto-match error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Count how many were matched
|
||||
env.Tx().QueryRow(env.Ctx(),
|
||||
`SELECT COUNT(*) FROM account_bank_statement_line WHERE id = ANY($1) AND is_reconciled = true`,
|
||||
importedIDs).Scan(&matchCount)
|
||||
}
|
||||
|
||||
if err := env.Commit(); err != nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32603, Message: fmt.Sprintf("Commit: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
s.writeJSONRPC(w, req.ID, map[string]interface{}{
|
||||
"statement_id": stmtID,
|
||||
"imported": len(importedIDs),
|
||||
"matched": matchCount,
|
||||
"errors": errors,
|
||||
}, nil)
|
||||
}
|
||||
|
||||
// bankColumnMapping maps CSV columns to bank statement fields.
|
||||
type bankColumnMapping struct {
|
||||
Date int `json:"date"` // column index for date
|
||||
Amount int `json:"amount"` // column index for amount
|
||||
Label int `json:"label"` // column index for label/reference
|
||||
PartnerName int `json:"partner_name"` // column index for partner name (-1 = skip)
|
||||
AccountNumber int `json:"account_number"` // column index for account number (-1 = skip)
|
||||
}
|
||||
|
||||
// detectDelimiter guesses the CSV delimiter (comma, semicolon, or tab).
|
||||
func detectDelimiter(data string) rune {
|
||||
firstLine := data
|
||||
if idx := strings.IndexByte(data, '\n'); idx > 0 {
|
||||
firstLine = data[:idx]
|
||||
}
|
||||
semicolons := strings.Count(firstLine, ";")
|
||||
commas := strings.Count(firstLine, ",")
|
||||
tabs := strings.Count(firstLine, "\t")
|
||||
|
||||
if semicolons > commas && semicolons > tabs {
|
||||
return ';'
|
||||
}
|
||||
if tabs > commas {
|
||||
return '\t'
|
||||
}
|
||||
return ','
|
||||
}
|
||||
|
||||
// safeCol returns the value at index i, or "" if out of bounds.
|
||||
func safeCol(row []string, i int) string {
|
||||
if i < 0 || i >= len(row) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(row[i])
|
||||
}
|
||||
|
||||
// parseFlexDate tries multiple date formats and returns YYYY-MM-DD.
|
||||
func parseFlexDate(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
formats := []string{
|
||||
"2006-01-02",
|
||||
"02.01.2006", // DD.MM.YYYY (common in EU)
|
||||
"01/02/2006", // MM/DD/YYYY
|
||||
"02/01/2006", // DD/MM/YYYY
|
||||
"2006/01/02",
|
||||
"Jan 2, 2006",
|
||||
"2 Jan 2006",
|
||||
"02-01-2006",
|
||||
"01-02-2006",
|
||||
time.RFC3339,
|
||||
}
|
||||
for _, f := range formats {
|
||||
if t, err := time.Parse(f, s); err == nil {
|
||||
return t.Format("2006-01-02")
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseAmount parses a monetary amount string, handling comma/dot decimals and negative formats.
|
||||
func parseAmount(s string) float64 {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
// Remove currency symbols and whitespace
|
||||
s = strings.NewReplacer("€", "", "$", "", "£", "", " ", "", "\u00a0", "").Replace(s)
|
||||
|
||||
// Handle European format: 1.234,56 → 1234.56
|
||||
if strings.Contains(s, ",") && strings.Contains(s, ".") {
|
||||
if strings.LastIndex(s, ",") > strings.LastIndex(s, ".") {
|
||||
// comma is decimal: 1.234,56
|
||||
s = strings.ReplaceAll(s, ".", "")
|
||||
s = strings.ReplaceAll(s, ",", ".")
|
||||
} else {
|
||||
// dot is decimal: 1,234.56
|
||||
s = strings.ReplaceAll(s, ",", "")
|
||||
}
|
||||
} else if strings.Contains(s, ",") {
|
||||
// Only comma: assume decimal separator
|
||||
s = strings.ReplaceAll(s, ",", ".")
|
||||
}
|
||||
|
||||
v, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return v
|
||||
}
|
||||
241
pkg/server/bus.go
Normal file
241
pkg/server/bus.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Bus implements a simple long-polling message bus for Discuss.
|
||||
// Mirrors: odoo/addons/bus/models/bus.py ImBus
|
||||
//
|
||||
// Channels subscribe to notifications. A long-poll request blocks until
|
||||
// a notification arrives or the timeout expires.
|
||||
type Bus struct {
|
||||
mu sync.Mutex
|
||||
channels map[int64][]chan busNotification
|
||||
lastID int64
|
||||
}
|
||||
|
||||
type busNotification struct {
|
||||
ID int64 `json:"id"`
|
||||
Channel string `json:"channel"`
|
||||
Message interface{} `json:"message"`
|
||||
}
|
||||
|
||||
// NewBus creates a new message bus.
|
||||
func NewBus() *Bus {
|
||||
return &Bus{
|
||||
channels: make(map[int64][]chan busNotification),
|
||||
}
|
||||
}
|
||||
|
||||
// Notify sends a notification to all subscribers of a channel.
|
||||
func (b *Bus) Notify(channelID int64, channel string, message interface{}) {
|
||||
b.mu.Lock()
|
||||
b.lastID++
|
||||
notif := busNotification{
|
||||
ID: b.lastID,
|
||||
Channel: channel,
|
||||
Message: message,
|
||||
}
|
||||
subs := b.channels[channelID]
|
||||
b.mu.Unlock()
|
||||
|
||||
for _, ch := range subs {
|
||||
select {
|
||||
case ch <- notif:
|
||||
default:
|
||||
// subscriber buffer full, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe creates a subscription for a partner's channels.
|
||||
func (b *Bus) Subscribe(partnerID int64) chan busNotification {
|
||||
ch := make(chan busNotification, 10)
|
||||
b.mu.Lock()
|
||||
b.channels[partnerID] = append(b.channels[partnerID], ch)
|
||||
b.mu.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscription.
|
||||
func (b *Bus) Unsubscribe(partnerID int64, ch chan busNotification) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
subs := b.channels[partnerID]
|
||||
for i, s := range subs {
|
||||
if s == ch {
|
||||
b.channels[partnerID] = append(subs[:i], subs[i+1:]...)
|
||||
close(ch)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// registerBusRoutes adds the long-polling endpoint.
|
||||
func (s *Server) registerBusRoutes() {
|
||||
if s.bus == nil {
|
||||
s.bus = NewBus()
|
||||
}
|
||||
s.mux.HandleFunc("/longpolling/poll", s.handleBusPoll)
|
||||
s.mux.HandleFunc("/discuss/channel/messages", s.handleDiscussMessages)
|
||||
s.mux.HandleFunc("/discuss/channel/list", s.handleDiscussChannelList)
|
||||
}
|
||||
|
||||
// handleBusPoll implements long-polling for real-time notifications.
|
||||
// Mirrors: odoo/addons/bus/controllers/main.py poll()
|
||||
func (s *Server) handleBusPoll(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
sess := GetSession(r)
|
||||
if sess == nil {
|
||||
writeJSON(w, []interface{}{})
|
||||
return
|
||||
}
|
||||
|
||||
// Get partner ID
|
||||
var partnerID int64
|
||||
s.pool.QueryRow(r.Context(),
|
||||
`SELECT COALESCE(partner_id, 0) FROM res_users WHERE id = $1`, sess.UID,
|
||||
).Scan(&partnerID)
|
||||
|
||||
if partnerID == 0 {
|
||||
writeJSON(w, []interface{}{})
|
||||
return
|
||||
}
|
||||
|
||||
// Subscribe and wait for notifications (max 30s)
|
||||
ch := s.bus.Subscribe(partnerID)
|
||||
defer s.bus.Unsubscribe(partnerID, ch)
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case notif := <-ch:
|
||||
writeJSON(w, []busNotification{notif})
|
||||
case <-ctx.Done():
|
||||
writeJSON(w, []interface{}{}) // timeout, empty response
|
||||
}
|
||||
}
|
||||
|
||||
// handleDiscussMessages fetches messages for a channel via JSON-RPC.
|
||||
func (s *Server) handleDiscussMessages(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
sess := GetSession(r)
|
||||
if sess == nil {
|
||||
s.writeJSONRPC(w, nil, nil, &RPCError{Code: 100, Message: "Not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var req JSONRPCRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.writeJSONRPC(w, nil, nil, &RPCError{Code: -32700, Message: "Parse error"})
|
||||
return
|
||||
}
|
||||
|
||||
var params struct {
|
||||
ChannelID int64 `json:"channel_id"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "Invalid params"})
|
||||
return
|
||||
}
|
||||
if params.Limit <= 0 {
|
||||
params.Limit = 50
|
||||
}
|
||||
|
||||
rows, err := s.pool.Query(r.Context(),
|
||||
`SELECT m.id, m.body, m.date, m.author_id, COALESCE(p.name, '')
|
||||
FROM mail_message m
|
||||
LEFT JOIN res_partner p ON p.id = m.author_id
|
||||
WHERE m.model = 'mail.channel' AND m.res_id = $1
|
||||
ORDER BY m.id DESC LIMIT $2`, params.ChannelID, params.Limit)
|
||||
if err != nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32603, Message: fmt.Sprintf("Query: %v", err)})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var id, authorID int64
|
||||
var body, authorName string
|
||||
var date interface{}
|
||||
if err := rows.Scan(&id, &body, &date, &authorID, &authorName); err != nil {
|
||||
continue
|
||||
}
|
||||
msg := map[string]interface{}{
|
||||
"id": id, "body": body, "date": date,
|
||||
}
|
||||
if authorID > 0 {
|
||||
msg["author_id"] = []interface{}{authorID, authorName}
|
||||
} else {
|
||||
msg["author_id"] = false
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
if messages == nil {
|
||||
messages = []map[string]interface{}{}
|
||||
}
|
||||
s.writeJSONRPC(w, req.ID, messages, nil)
|
||||
}
|
||||
|
||||
// handleDiscussChannelList returns channels the current user is member of.
|
||||
func (s *Server) handleDiscussChannelList(w http.ResponseWriter, r *http.Request) {
|
||||
sess := GetSession(r)
|
||||
if sess == nil {
|
||||
s.writeJSONRPC(w, nil, nil, &RPCError{Code: 100, Message: "Not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var partnerID int64
|
||||
s.pool.QueryRow(r.Context(),
|
||||
`SELECT COALESCE(partner_id, 0) FROM res_users WHERE id = $1`, sess.UID,
|
||||
).Scan(&partnerID)
|
||||
|
||||
rows, err := s.pool.Query(r.Context(),
|
||||
`SELECT c.id, c.name, c.channel_type,
|
||||
(SELECT COUNT(*) FROM mail_channel_member WHERE channel_id = c.id) AS members
|
||||
FROM mail_channel c
|
||||
JOIN mail_channel_member cm ON cm.channel_id = c.id AND cm.partner_id = $1
|
||||
WHERE c.active = true
|
||||
ORDER BY c.last_message_date DESC NULLS LAST`, partnerID)
|
||||
if err != nil {
|
||||
log.Printf("discuss: channel list error: %v", err)
|
||||
writeJSON(w, []interface{}{})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var channels []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var name, channelType string
|
||||
var members int64
|
||||
if err := rows.Scan(&id, &name, &channelType, &members); err != nil {
|
||||
continue
|
||||
}
|
||||
channels = append(channels, map[string]interface{}{
|
||||
"id": id, "name": name, "channel_type": channelType, "member_count": members,
|
||||
})
|
||||
}
|
||||
if channels == nil {
|
||||
channels = []map[string]interface{}{}
|
||||
}
|
||||
writeJSON(w, channels)
|
||||
}
|
||||
@@ -6,37 +6,45 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/xuri/excelize/v2"
|
||||
"odoo-go/pkg/orm"
|
||||
)
|
||||
|
||||
// handleExportCSV exports records as CSV.
|
||||
// Mirrors: odoo/addons/web/controllers/export.py ExportController
|
||||
func (s *Server) handleExportCSV(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
// exportField describes a field in an export request.
|
||||
type exportField struct {
|
||||
Name string `json:"name"`
|
||||
Label string `json:"label"`
|
||||
}
|
||||
|
||||
// exportData holds the parsed and fetched data for an export operation.
|
||||
type exportData struct {
|
||||
Model string
|
||||
FieldNames []string
|
||||
Headers []string
|
||||
Records []orm.Values
|
||||
}
|
||||
|
||||
// parseExportRequest parses the common request/params/env/search logic shared by CSV and XLSX export.
|
||||
func (s *Server) parseExportRequest(w http.ResponseWriter, r *http.Request) (*exportData, error) {
|
||||
var req JSONRPCRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.writeJSONRPC(w, nil, nil, &RPCError{Code: -32700, Message: "Parse error"})
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var params struct {
|
||||
Data struct {
|
||||
Model string `json:"model"`
|
||||
Fields []exportField `json:"fields"`
|
||||
Domain []interface{} `json:"domain"`
|
||||
IDs []float64 `json:"ids"`
|
||||
Model string `json:"model"`
|
||||
Fields []exportField `json:"fields"`
|
||||
Domain []interface{} `json:"domain"`
|
||||
IDs []float64 `json:"ids"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "Invalid params"})
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract UID from session
|
||||
uid := int64(1)
|
||||
companyID := int64(1)
|
||||
if sess := GetSession(r); sess != nil {
|
||||
@@ -45,42 +53,31 @@ func (s *Server) handleExportCSV(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
env, err := orm.NewEnvironment(r.Context(), orm.EnvConfig{
|
||||
Pool: s.pool,
|
||||
UID: uid,
|
||||
CompanyID: companyID,
|
||||
Pool: s.pool, UID: uid, CompanyID: companyID,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, "Internal error", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
defer env.Close()
|
||||
|
||||
rs := env.Model(params.Data.Model)
|
||||
|
||||
// Determine which record IDs to export
|
||||
var ids []int64
|
||||
if len(params.Data.IDs) > 0 {
|
||||
for _, id := range params.Data.IDs {
|
||||
ids = append(ids, int64(id))
|
||||
}
|
||||
} else {
|
||||
// Search with domain
|
||||
domain := parseDomain([]interface{}{params.Data.Domain})
|
||||
found, err := rs.Search(domain, orm.SearchOpts{Limit: 10000})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
ids = found.IDs()
|
||||
}
|
||||
|
||||
if len(ids) == 0 {
|
||||
w.Header().Set("Content-Type", "text/csv")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s.csv", params.Data.Model))
|
||||
return
|
||||
}
|
||||
|
||||
// Extract field names
|
||||
var fieldNames []string
|
||||
var headers []string
|
||||
for _, f := range params.Data.Fields {
|
||||
@@ -92,42 +89,89 @@ func (s *Server) handleExportCSV(w http.ResponseWriter, r *http.Request) {
|
||||
headers = append(headers, label)
|
||||
}
|
||||
|
||||
// Read records
|
||||
records, err := rs.Browse(ids...).Read(fieldNames)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
var records []orm.Values
|
||||
if len(ids) > 0 {
|
||||
records, err = rs.Browse(ids...).Read(fieldNames)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := env.Commit(); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &exportData{
|
||||
Model: params.Data.Model, FieldNames: fieldNames, Headers: headers, Records: records,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleExportCSV exports records as CSV.
|
||||
// Mirrors: odoo/addons/web/controllers/export.py ExportController
|
||||
func (s *Server) handleExportCSV(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := s.parseExportRequest(w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Write CSV
|
||||
w.Header().Set("Content-Type", "text/csv; charset=utf-8")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s.csv", params.Data.Model))
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s.csv", data.Model))
|
||||
|
||||
writer := csv.NewWriter(w)
|
||||
defer writer.Flush()
|
||||
|
||||
// Header row
|
||||
writer.Write(headers)
|
||||
|
||||
// Data rows
|
||||
for _, rec := range records {
|
||||
row := make([]string, len(fieldNames))
|
||||
for i, fname := range fieldNames {
|
||||
writer.Write(data.Headers)
|
||||
for _, rec := range data.Records {
|
||||
row := make([]string, len(data.FieldNames))
|
||||
for i, fname := range data.FieldNames {
|
||||
row[i] = formatCSVValue(rec[fname])
|
||||
}
|
||||
writer.Write(row)
|
||||
}
|
||||
}
|
||||
|
||||
// exportField describes a field in an export request.
|
||||
type exportField struct {
|
||||
Name string `json:"name"`
|
||||
Label string `json:"label"`
|
||||
// handleExportXLSX exports records as XLSX (Excel).
|
||||
// Mirrors: odoo/addons/web/controllers/export.py ExportXlsxController
|
||||
func (s *Server) handleExportXLSX(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := s.parseExportRequest(w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
f := excelize.NewFile()
|
||||
sheet := "Sheet1"
|
||||
|
||||
headerStyle, _ := f.NewStyle(&excelize.Style{
|
||||
Font: &excelize.Font{Bold: true},
|
||||
})
|
||||
for i, h := range data.Headers {
|
||||
cell, _ := excelize.CoordinatesToCellName(i+1, 1)
|
||||
f.SetCellValue(sheet, cell, h)
|
||||
f.SetCellStyle(sheet, cell, cell, headerStyle)
|
||||
}
|
||||
|
||||
for rowIdx, rec := range data.Records {
|
||||
for colIdx, fname := range data.FieldNames {
|
||||
cell, _ := excelize.CoordinatesToCellName(colIdx+1, rowIdx+2)
|
||||
f.SetCellValue(sheet, cell, formatCSVValue(rec[fname]))
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s.xlsx", data.Model))
|
||||
f.Write(w)
|
||||
}
|
||||
|
||||
// formatCSVValue converts a field value to a CSV string.
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -55,9 +56,11 @@ func (s *Server) handleImage(w http.ResponseWriter, r *http.Request) {
|
||||
table := m.Table()
|
||||
var data []byte
|
||||
ctx := r.Context()
|
||||
_ = s.pool.QueryRow(ctx,
|
||||
if err := s.pool.QueryRow(ctx,
|
||||
fmt.Sprintf(`SELECT "%s" FROM "%s" WHERE id = $1`, f.Column(), table), id,
|
||||
).Scan(&data)
|
||||
).Scan(&data); err != nil {
|
||||
log.Printf("warning: image query failed for %s.%s id=%d: %v", model, field, id, err)
|
||||
}
|
||||
if len(data) > 0 {
|
||||
// Detect content type
|
||||
contentType := http.DetectContentType(data)
|
||||
@@ -76,9 +79,11 @@ func (s *Server) handleImage(w http.ResponseWriter, r *http.Request) {
|
||||
m := orm.Registry.Get(model)
|
||||
if m != nil {
|
||||
var name string
|
||||
_ = s.pool.QueryRow(r.Context(),
|
||||
if err := s.pool.QueryRow(r.Context(),
|
||||
fmt.Sprintf(`SELECT COALESCE(name, '') FROM "%s" WHERE id = $1`, m.Table()), id,
|
||||
).Scan(&name)
|
||||
).Scan(&name); err != nil {
|
||||
log.Printf("warning: image name lookup failed for %s id=%d: %v", model, id, err)
|
||||
}
|
||||
if len(name) > 0 {
|
||||
initial = strings.ToUpper(name[:1])
|
||||
}
|
||||
|
||||
223
pkg/server/import.go
Normal file
223
pkg/server/import.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"odoo-go/pkg/orm"
|
||||
)
|
||||
|
||||
// handleImportCSV imports records from a CSV file into any model.
|
||||
// Accepts JSON body with: model, fields (mapping), csv_data (raw CSV string).
|
||||
// Mirrors: odoo/addons/base_import/controllers/main.py ImportController
|
||||
func (s *Server) handleImportCSV(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req JSONRPCRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.writeJSONRPC(w, nil, nil, &RPCError{Code: -32700, Message: "Parse error"})
|
||||
return
|
||||
}
|
||||
|
||||
var params struct {
|
||||
Model string `json:"model"`
|
||||
Fields []importFieldMap `json:"fields"`
|
||||
CSVData string `json:"csv_data"`
|
||||
HasHeader bool `json:"has_header"`
|
||||
DryRun bool `json:"dry_run"`
|
||||
}
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "Invalid params"})
|
||||
return
|
||||
}
|
||||
|
||||
if params.Model == "" || len(params.Fields) == 0 || params.CSVData == "" {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "model, fields, and csv_data are required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify model exists
|
||||
m := orm.Registry.Get(params.Model)
|
||||
if m == nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: fmt.Sprintf("Unknown model: %s", params.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse CSV
|
||||
reader := csv.NewReader(strings.NewReader(params.CSVData))
|
||||
reader.LazyQuotes = true
|
||||
reader.TrimLeadingSpace = true
|
||||
|
||||
var allRows [][]string
|
||||
for {
|
||||
row, err := reader.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: fmt.Sprintf("CSV parse error: %v", err)})
|
||||
return
|
||||
}
|
||||
allRows = append(allRows, row)
|
||||
}
|
||||
|
||||
if len(allRows) == 0 {
|
||||
s.writeJSONRPC(w, req.ID, map[string]interface{}{"ids": []int64{}, "count": 0}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Skip header row if present
|
||||
dataRows := allRows
|
||||
if params.HasHeader && len(allRows) > 1 {
|
||||
dataRows = allRows[1:]
|
||||
}
|
||||
|
||||
// Build field mapping: CSV column index → ORM field name
|
||||
type colMapping struct {
|
||||
colIndex int
|
||||
fieldName string
|
||||
fieldType orm.FieldType
|
||||
}
|
||||
var mappings []colMapping
|
||||
for _, fm := range params.Fields {
|
||||
if fm.FieldName == "" || fm.ColumnIndex < 0 {
|
||||
continue
|
||||
}
|
||||
f := m.GetField(fm.FieldName)
|
||||
if f == nil {
|
||||
continue // skip unknown fields
|
||||
}
|
||||
mappings = append(mappings, colMapping{
|
||||
colIndex: fm.ColumnIndex,
|
||||
fieldName: fm.FieldName,
|
||||
fieldType: f.Type,
|
||||
})
|
||||
}
|
||||
|
||||
if len(mappings) == 0 {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "No valid field mappings"})
|
||||
return
|
||||
}
|
||||
|
||||
uid := int64(1)
|
||||
companyID := int64(1)
|
||||
if sess := GetSession(r); sess != nil {
|
||||
uid = sess.UID
|
||||
companyID = sess.CompanyID
|
||||
}
|
||||
|
||||
env, err := orm.NewEnvironment(r.Context(), orm.EnvConfig{
|
||||
Pool: s.pool,
|
||||
UID: uid,
|
||||
CompanyID: companyID,
|
||||
})
|
||||
if err != nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32603, Message: "Internal error"})
|
||||
return
|
||||
}
|
||||
defer env.Close()
|
||||
|
||||
rs := env.Model(params.Model)
|
||||
|
||||
var createdIDs []int64
|
||||
var errors []importError
|
||||
|
||||
for rowIdx, row := range dataRows {
|
||||
vals := make(orm.Values)
|
||||
for _, cm := range mappings {
|
||||
if cm.colIndex >= len(row) {
|
||||
continue
|
||||
}
|
||||
raw := strings.TrimSpace(row[cm.colIndex])
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
vals[cm.fieldName] = coerceImportValue(raw, cm.fieldType)
|
||||
}
|
||||
|
||||
if len(vals) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if params.DryRun {
|
||||
continue // validate only, don't create
|
||||
}
|
||||
|
||||
rec, err := rs.Create(vals)
|
||||
if err != nil {
|
||||
errors = append(errors, importError{
|
||||
Row: rowIdx + 1,
|
||||
Message: err.Error(),
|
||||
})
|
||||
log.Printf("import: row %d error: %v", rowIdx+1, err)
|
||||
continue
|
||||
}
|
||||
createdIDs = append(createdIDs, rec.ID())
|
||||
}
|
||||
|
||||
if err := env.Commit(); err != nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32603, Message: fmt.Sprintf("Commit error: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"ids": createdIDs,
|
||||
"count": len(createdIDs),
|
||||
"errors": errors,
|
||||
"dry_run": params.DryRun,
|
||||
}
|
||||
s.writeJSONRPC(w, req.ID, result, nil)
|
||||
}
|
||||
|
||||
// importFieldMap maps a CSV column to an ORM field.
|
||||
type importFieldMap struct {
|
||||
ColumnIndex int `json:"column_index"`
|
||||
FieldName string `json:"field_name"`
|
||||
}
|
||||
|
||||
// importError describes a per-row import error.
|
||||
type importError struct {
|
||||
Row int `json:"row"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// coerceImportValue converts a raw CSV string to the appropriate Go type for ORM Create.
|
||||
func coerceImportValue(raw string, ft orm.FieldType) interface{} {
|
||||
switch ft {
|
||||
case orm.TypeInteger:
|
||||
v, err := strconv.ParseInt(raw, 10, 64)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
case orm.TypeFloat, orm.TypeMonetary:
|
||||
// Handle comma as decimal separator
|
||||
raw = strings.ReplaceAll(raw, ",", ".")
|
||||
v, err := strconv.ParseFloat(raw, 64)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
case orm.TypeBoolean:
|
||||
lower := strings.ToLower(raw)
|
||||
return lower == "true" || lower == "1" || lower == "yes" || lower == "ja"
|
||||
case orm.TypeMany2one:
|
||||
// Try as integer ID first, then as name_search later
|
||||
v, err := strconv.ParseInt(raw, 10, 64)
|
||||
if err != nil {
|
||||
return raw // pass as string, ORM may handle name_create
|
||||
}
|
||||
return v
|
||||
default:
|
||||
return raw
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -43,13 +44,19 @@ func (w *statusWriter) WriteHeader(code int) {
|
||||
func AuthMiddleware(store *SessionStore, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Public endpoints (no auth required)
|
||||
path := r.URL.Path
|
||||
path := filepath.Clean(r.URL.Path)
|
||||
if path == "/health" ||
|
||||
path == "/web/login" ||
|
||||
path == "/web/session/authenticate" ||
|
||||
path == "/web/session/logout" ||
|
||||
strings.HasPrefix(path, "/web/database/") ||
|
||||
path == "/web/database/manager" ||
|
||||
path == "/web/database/create" ||
|
||||
path == "/web/database/list" ||
|
||||
path == "/web/webclient/version_info" ||
|
||||
path == "/web/setup/wizard" ||
|
||||
path == "/web/setup/wizard/save" ||
|
||||
path == "/web/portal/signup" ||
|
||||
path == "/web/portal/reset_password" ||
|
||||
strings.Contains(path, "/static/") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
@@ -58,8 +65,14 @@ func AuthMiddleware(store *SessionStore, next http.Handler) http.Handler {
|
||||
// Check session cookie
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err != nil || cookie.Value == "" {
|
||||
// Also check JSON-RPC params for session_id (Odoo sends it both ways)
|
||||
next.ServeHTTP(w, r) // For now, allow through — UID defaults to 1
|
||||
// No session cookie — reject protected endpoints
|
||||
if r.Header.Get("Content-Type") == "application/json" ||
|
||||
strings.HasPrefix(path, "/web/dataset/") ||
|
||||
strings.HasPrefix(path, "/jsonrpc") {
|
||||
http.Error(w, `{"jsonrpc":"2.0","error":{"code":100,"message":"Session expired"}}`, http.StatusUnauthorized)
|
||||
} else {
|
||||
http.Redirect(w, r, "/web/login", http.StatusFound)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
379
pkg/server/portal.go
Normal file
379
pkg/server/portal.go
Normal file
@@ -0,0 +1,379 @@
|
||||
// Package server — Portal controllers for external (customer/supplier) access.
|
||||
// Mirrors: odoo/addons/portal/controllers/portal.py CustomerPortal
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// registerPortalRoutes registers all /my/* portal endpoints.
|
||||
func (s *Server) registerPortalRoutes() {
|
||||
s.mux.HandleFunc("/my", s.handlePortalHome)
|
||||
s.mux.HandleFunc("/my/", s.handlePortalDispatch)
|
||||
s.mux.HandleFunc("/my/home", s.handlePortalHome)
|
||||
s.mux.HandleFunc("/my/invoices", s.handlePortalInvoices)
|
||||
s.mux.HandleFunc("/my/orders", s.handlePortalOrders)
|
||||
s.mux.HandleFunc("/my/pickings", s.handlePortalPickings)
|
||||
s.mux.HandleFunc("/my/account", s.handlePortalAccount)
|
||||
}
|
||||
|
||||
// handlePortalDispatch routes /my/* sub-paths to the correct handler.
|
||||
func (s *Server) handlePortalDispatch(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/my/home":
|
||||
s.handlePortalHome(w, r)
|
||||
case "/my/invoices":
|
||||
s.handlePortalInvoices(w, r)
|
||||
case "/my/orders":
|
||||
s.handlePortalOrders(w, r)
|
||||
case "/my/pickings":
|
||||
s.handlePortalPickings(w, r)
|
||||
case "/my/account":
|
||||
s.handlePortalAccount(w, r)
|
||||
default:
|
||||
s.handlePortalHome(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// portalPartnerID resolves the partner_id of the currently logged-in portal user.
|
||||
// Returns (partnerID, error). If session is missing, writes an error response and returns 0.
|
||||
func (s *Server) portalPartnerID(w http.ResponseWriter, r *http.Request) (int64, bool) {
|
||||
sess := GetSession(r)
|
||||
if sess == nil {
|
||||
writePortalError(w, http.StatusUnauthorized, "Not authenticated")
|
||||
return 0, false
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var partnerID int64
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT partner_id FROM res_users WHERE id = $1 AND active = true`,
|
||||
sess.UID).Scan(&partnerID)
|
||||
if err != nil {
|
||||
log.Printf("portal: cannot resolve partner_id for uid=%d: %v", sess.UID, err)
|
||||
writePortalError(w, http.StatusForbidden, "User not found")
|
||||
return 0, false
|
||||
}
|
||||
return partnerID, true
|
||||
}
|
||||
|
||||
// handlePortalHome returns the portal dashboard with document counts.
|
||||
// Mirrors: odoo/addons/portal/controllers/portal.py CustomerPortal.home()
|
||||
func (s *Server) handlePortalHome(w http.ResponseWriter, r *http.Request) {
|
||||
partnerID, ok := s.portalPartnerID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var invoiceCount, orderCount, pickingCount int64
|
||||
|
||||
// Count invoices (account.move with move_type in ('out_invoice','out_refund'))
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT COUNT(*) FROM account_move
|
||||
WHERE partner_id = $1 AND move_type IN ('out_invoice','out_refund')
|
||||
AND state = 'posted'`, partnerID).Scan(&invoiceCount)
|
||||
if err != nil {
|
||||
log.Printf("portal: invoice count error: %v", err)
|
||||
}
|
||||
|
||||
// Count sale orders (confirmed or done)
|
||||
err = s.pool.QueryRow(ctx,
|
||||
`SELECT COUNT(*) FROM sale_order
|
||||
WHERE partner_id = $1 AND state IN ('sale','done')`, partnerID).Scan(&orderCount)
|
||||
if err != nil {
|
||||
log.Printf("portal: order count error: %v", err)
|
||||
}
|
||||
|
||||
// Count pickings (stock.picking)
|
||||
err = s.pool.QueryRow(ctx,
|
||||
`SELECT COUNT(*) FROM stock_picking
|
||||
WHERE partner_id = $1 AND state != 'cancel'`, partnerID).Scan(&pickingCount)
|
||||
if err != nil {
|
||||
log.Printf("portal: picking count error: %v", err)
|
||||
}
|
||||
|
||||
writePortalJSON(w, map[string]interface{}{
|
||||
"counters": map[string]int64{
|
||||
"invoice_count": invoiceCount,
|
||||
"order_count": orderCount,
|
||||
"picking_count": pickingCount,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// handlePortalInvoices lists invoices for the current portal user.
|
||||
// Mirrors: odoo/addons/portal/controllers/portal.py CustomerPortal.portal_my_invoices()
|
||||
func (s *Server) handlePortalInvoices(w http.ResponseWriter, r *http.Request) {
|
||||
partnerID, ok := s.portalPartnerID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
rows, err := s.pool.Query(ctx,
|
||||
`SELECT m.id, m.name, m.move_type, m.state, m.date,
|
||||
m.amount_total::float8, m.amount_residual::float8,
|
||||
m.payment_state, COALESCE(m.ref, '')
|
||||
FROM account_move m
|
||||
WHERE m.partner_id = $1
|
||||
AND m.move_type IN ('out_invoice','out_refund')
|
||||
AND m.state = 'posted'
|
||||
ORDER BY m.date DESC
|
||||
LIMIT 80`, partnerID)
|
||||
if err != nil {
|
||||
log.Printf("portal: invoice query error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Failed to load invoices")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var invoices []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var name, moveType, state, paymentState, ref string
|
||||
var date time.Time
|
||||
var amountTotal, amountResidual float64
|
||||
if err := rows.Scan(&id, &name, &moveType, &state, &date,
|
||||
&amountTotal, &amountResidual, &paymentState, &ref); err != nil {
|
||||
log.Printf("portal: invoice scan error: %v", err)
|
||||
continue
|
||||
}
|
||||
invoices = append(invoices, map[string]interface{}{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"move_type": moveType,
|
||||
"state": state,
|
||||
"date": date.Format("2006-01-02"),
|
||||
"amount_total": amountTotal,
|
||||
"amount_residual": amountResidual,
|
||||
"payment_state": paymentState,
|
||||
"ref": ref,
|
||||
})
|
||||
}
|
||||
if invoices == nil {
|
||||
invoices = []map[string]interface{}{}
|
||||
}
|
||||
writePortalJSON(w, map[string]interface{}{"invoices": invoices})
|
||||
}
|
||||
|
||||
// handlePortalOrders lists sale orders for the current portal user.
|
||||
// Mirrors: odoo/addons/portal/controllers/portal.py CustomerPortal.portal_my_orders()
|
||||
func (s *Server) handlePortalOrders(w http.ResponseWriter, r *http.Request) {
|
||||
partnerID, ok := s.portalPartnerID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
rows, err := s.pool.Query(ctx,
|
||||
`SELECT so.id, so.name, so.state, so.date_order,
|
||||
so.amount_total::float8, COALESCE(so.invoice_status, ''),
|
||||
COALESCE(so.delivery_status, '')
|
||||
FROM sale_order so
|
||||
WHERE so.partner_id = $1
|
||||
AND so.state IN ('sale','done')
|
||||
ORDER BY so.date_order DESC
|
||||
LIMIT 80`, partnerID)
|
||||
if err != nil {
|
||||
log.Printf("portal: order query error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Failed to load orders")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var orders []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var name, state, invoiceStatus, deliveryStatus string
|
||||
var dateOrder time.Time
|
||||
var amountTotal float64
|
||||
if err := rows.Scan(&id, &name, &state, &dateOrder,
|
||||
&amountTotal, &invoiceStatus, &deliveryStatus); err != nil {
|
||||
log.Printf("portal: order scan error: %v", err)
|
||||
continue
|
||||
}
|
||||
orders = append(orders, map[string]interface{}{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"state": state,
|
||||
"date_order": dateOrder.Format("2006-01-02 15:04:05"),
|
||||
"amount_total": amountTotal,
|
||||
"invoice_status": invoiceStatus,
|
||||
"delivery_status": deliveryStatus,
|
||||
})
|
||||
}
|
||||
if orders == nil {
|
||||
orders = []map[string]interface{}{}
|
||||
}
|
||||
writePortalJSON(w, map[string]interface{}{"orders": orders})
|
||||
}
|
||||
|
||||
// handlePortalPickings lists stock pickings for the current portal user.
|
||||
// Mirrors: odoo/addons/portal/controllers/portal.py CustomerPortal.portal_my_pickings()
|
||||
func (s *Server) handlePortalPickings(w http.ResponseWriter, r *http.Request) {
|
||||
partnerID, ok := s.portalPartnerID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
rows, err := s.pool.Query(ctx,
|
||||
`SELECT sp.id, sp.name, sp.state, sp.scheduled_date,
|
||||
COALESCE(sp.origin, ''),
|
||||
COALESCE(spt.name, '') AS picking_type_name
|
||||
FROM stock_picking sp
|
||||
LEFT JOIN stock_picking_type spt ON spt.id = sp.picking_type_id
|
||||
WHERE sp.partner_id = $1
|
||||
AND sp.state != 'cancel'
|
||||
ORDER BY sp.scheduled_date DESC
|
||||
LIMIT 80`, partnerID)
|
||||
if err != nil {
|
||||
log.Printf("portal: picking query error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Failed to load pickings")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var pickings []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var name, state, origin, pickingTypeName string
|
||||
var scheduledDate time.Time
|
||||
if err := rows.Scan(&id, &name, &state, &scheduledDate,
|
||||
&origin, &pickingTypeName); err != nil {
|
||||
log.Printf("portal: picking scan error: %v", err)
|
||||
continue
|
||||
}
|
||||
pickings = append(pickings, map[string]interface{}{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"state": state,
|
||||
"scheduled_date": scheduledDate.Format("2006-01-02 15:04:05"),
|
||||
"origin": origin,
|
||||
"picking_type_name": pickingTypeName,
|
||||
})
|
||||
}
|
||||
if pickings == nil {
|
||||
pickings = []map[string]interface{}{}
|
||||
}
|
||||
writePortalJSON(w, map[string]interface{}{"pickings": pickings})
|
||||
}
|
||||
|
||||
// handlePortalAccount returns/updates the portal user's profile.
|
||||
// GET: returns user profile. POST: updates name/email/phone/street/city/zip.
|
||||
// Mirrors: odoo/addons/portal/controllers/portal.py CustomerPortal.account()
|
||||
func (s *Server) handlePortalAccount(w http.ResponseWriter, r *http.Request) {
|
||||
partnerID, ok := s.portalPartnerID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if r.Method == http.MethodPost {
|
||||
// Update profile
|
||||
var body struct {
|
||||
Name *string `json:"name"`
|
||||
Email *string `json:"email"`
|
||||
Phone *string `json:"phone"`
|
||||
Street *string `json:"street"`
|
||||
City *string `json:"city"`
|
||||
Zip *string `json:"zip"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writePortalError(w, http.StatusBadRequest, "Invalid JSON")
|
||||
return
|
||||
}
|
||||
|
||||
// Build SET clause dynamically with parameterized placeholders
|
||||
sets := make([]string, 0, 6)
|
||||
args := make([]interface{}, 0, 7)
|
||||
idx := 1
|
||||
addField := func(col string, val *string) {
|
||||
if val != nil {
|
||||
sets = append(sets, fmt.Sprintf("%s = $%d", col, idx))
|
||||
args = append(args, *val)
|
||||
idx++
|
||||
}
|
||||
}
|
||||
addField("name", body.Name)
|
||||
addField("email", body.Email)
|
||||
addField("phone", body.Phone)
|
||||
addField("street", body.Street)
|
||||
addField("city", body.City)
|
||||
addField("zip", body.Zip)
|
||||
|
||||
if len(sets) > 0 {
|
||||
args = append(args, partnerID)
|
||||
query := "UPDATE res_partner SET "
|
||||
for j, set := range sets {
|
||||
if j > 0 {
|
||||
query += ", "
|
||||
}
|
||||
query += set
|
||||
}
|
||||
query += fmt.Sprintf(" WHERE id = $%d", idx)
|
||||
if _, err := s.pool.Exec(ctx, query, args...); err != nil {
|
||||
log.Printf("portal: account update error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Update failed")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
writePortalJSON(w, map[string]interface{}{"success": true})
|
||||
return
|
||||
}
|
||||
|
||||
// GET — return profile
|
||||
var name, email, phone, street, city, zip string
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT COALESCE(name,''), COALESCE(email,''), COALESCE(phone,''),
|
||||
COALESCE(street,''), COALESCE(city,''), COALESCE(zip,'')
|
||||
FROM res_partner WHERE id = $1`, partnerID).Scan(
|
||||
&name, &email, &phone, &street, &city, &zip)
|
||||
if err != nil {
|
||||
log.Printf("portal: account read error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Failed to load profile")
|
||||
return
|
||||
}
|
||||
|
||||
writePortalJSON(w, map[string]interface{}{
|
||||
"name": name,
|
||||
"email": email,
|
||||
"phone": phone,
|
||||
"street": street,
|
||||
"city": city,
|
||||
"zip": zip,
|
||||
})
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func writePortalJSON(w http.ResponseWriter, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
func writePortalError(w http.ResponseWriter, status int, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": message})
|
||||
}
|
||||
313
pkg/server/portal_signup.go
Normal file
313
pkg/server/portal_signup.go
Normal file
@@ -0,0 +1,313 @@
|
||||
// Package server — Portal signup and password reset.
|
||||
// Mirrors: odoo/addons/auth_signup/controllers/main.py AuthSignupHome
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"odoo-go/pkg/tools"
|
||||
)
|
||||
|
||||
// registerPortalSignupRoutes registers /web/portal/* public endpoints.
|
||||
func (s *Server) registerPortalSignupRoutes() {
|
||||
s.mux.HandleFunc("/web/portal/signup", s.handlePortalSignup)
|
||||
s.mux.HandleFunc("/web/portal/reset_password", s.handlePortalResetPassword)
|
||||
}
|
||||
|
||||
// handlePortalSignup creates a new portal user with share=true and a matching res.partner.
|
||||
// Mirrors: odoo/addons/auth_signup/controllers/main.py AuthSignupHome.web_auth_signup()
|
||||
func (s *Server) handlePortalSignup(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writePortalError(w, http.StatusMethodNotAllowed, "POST required")
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writePortalError(w, http.StatusBadRequest, "Invalid JSON")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
body.Name = strings.TrimSpace(body.Name)
|
||||
body.Email = strings.TrimSpace(body.Email)
|
||||
if body.Name == "" || body.Email == "" || body.Password == "" {
|
||||
writePortalError(w, http.StatusBadRequest, "Name, email, and password are required")
|
||||
return
|
||||
}
|
||||
if len(body.Password) < 8 {
|
||||
writePortalError(w, http.StatusBadRequest, "Password must be at least 8 characters")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Check if login already exists
|
||||
var exists bool
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM res_users WHERE login = $1)`, body.Email).Scan(&exists)
|
||||
if err != nil {
|
||||
log.Printf("portal signup: check existing user error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Internal error")
|
||||
return
|
||||
}
|
||||
if exists {
|
||||
writePortalError(w, http.StatusConflict, "An account with this email already exists")
|
||||
return
|
||||
}
|
||||
|
||||
// Hash password
|
||||
hashedPw, err := tools.HashPassword(body.Password)
|
||||
if err != nil {
|
||||
log.Printf("portal signup: hash password error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Internal error")
|
||||
return
|
||||
}
|
||||
|
||||
// Get default company
|
||||
var companyID int64
|
||||
err = s.pool.QueryRow(ctx,
|
||||
`SELECT id FROM res_company WHERE active = true ORDER BY id LIMIT 1`).Scan(&companyID)
|
||||
if err != nil {
|
||||
log.Printf("portal signup: get company error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Internal error")
|
||||
return
|
||||
}
|
||||
|
||||
// Begin transaction — create partner + user atomically
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
log.Printf("portal signup: begin tx error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Internal error")
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// Create res.partner
|
||||
var partnerID int64
|
||||
err = tx.QueryRow(ctx,
|
||||
`INSERT INTO res_partner (name, email, active, company_id, customer_rank)
|
||||
VALUES ($1, $2, true, $3, 1)
|
||||
RETURNING id`, body.Name, body.Email, companyID).Scan(&partnerID)
|
||||
if err != nil {
|
||||
log.Printf("portal signup: create partner error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Failed to create account")
|
||||
return
|
||||
}
|
||||
|
||||
// Create res.users with share=true
|
||||
var userID int64
|
||||
err = tx.QueryRow(ctx,
|
||||
`INSERT INTO res_users (login, password, active, partner_id, company_id, share)
|
||||
VALUES ($1, $2, true, $3, $4, true)
|
||||
RETURNING id`, body.Email, hashedPw, partnerID, companyID).Scan(&userID)
|
||||
if err != nil {
|
||||
log.Printf("portal signup: create user error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Failed to create account")
|
||||
return
|
||||
}
|
||||
|
||||
// Add user to group_portal (not group_user)
|
||||
var groupPortalID int64
|
||||
err = tx.QueryRow(ctx,
|
||||
`SELECT g.id FROM res_groups g
|
||||
JOIN ir_model_data imd ON imd.res_id = g.id AND imd.model = 'res.groups'
|
||||
WHERE imd.module = 'base' AND imd.name = 'group_portal'`).Scan(&groupPortalID)
|
||||
if err != nil {
|
||||
// group_portal might not exist yet — create it
|
||||
err = tx.QueryRow(ctx,
|
||||
`INSERT INTO res_groups (name) VALUES ('Portal') RETURNING id`).Scan(&groupPortalID)
|
||||
if err != nil {
|
||||
log.Printf("portal signup: create group_portal error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Failed to create account")
|
||||
return
|
||||
}
|
||||
_, err = tx.Exec(ctx,
|
||||
`INSERT INTO ir_model_data (module, name, model, res_id)
|
||||
VALUES ('base', 'group_portal', 'res.groups', $1)
|
||||
ON CONFLICT DO NOTHING`, groupPortalID)
|
||||
if err != nil {
|
||||
log.Printf("portal signup: create group_portal xmlid error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx,
|
||||
`INSERT INTO res_groups_res_users_rel (res_groups_id, res_users_id)
|
||||
VALUES ($1, $2) ON CONFLICT DO NOTHING`, groupPortalID, userID)
|
||||
if err != nil {
|
||||
log.Printf("portal signup: add user to group_portal error: %v", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
log.Printf("portal signup: commit error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Failed to create account")
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("portal signup: created portal user id=%d login=%s partner_id=%d",
|
||||
userID, body.Email, partnerID)
|
||||
|
||||
writePortalJSON(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"user_id": userID,
|
||||
"partner_id": partnerID,
|
||||
"message": "Account created successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// handlePortalResetPassword handles password reset requests.
|
||||
// POST with {"email":"..."}: generates a reset token and sends an email.
|
||||
// POST with {"token":"...","password":"..."}: resets the password.
|
||||
// Mirrors: odoo/addons/auth_signup/controllers/main.py AuthSignupHome.web_auth_reset_password()
|
||||
func (s *Server) handlePortalResetPassword(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writePortalError(w, http.StatusMethodNotAllowed, "POST required")
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Email string `json:"email"`
|
||||
Token string `json:"token"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writePortalError(w, http.StatusBadRequest, "Invalid JSON")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Phase 2: Token + new password → reset
|
||||
if body.Token != "" && body.Password != "" {
|
||||
s.handleResetWithToken(w, ctx, body.Token, body.Password)
|
||||
return
|
||||
}
|
||||
|
||||
// Phase 1: Email → generate token + send email
|
||||
if body.Email == "" {
|
||||
writePortalError(w, http.StatusBadRequest, "Email is required")
|
||||
return
|
||||
}
|
||||
|
||||
s.handleResetRequest(w, ctx, strings.TrimSpace(body.Email))
|
||||
}
|
||||
|
||||
// handleResetRequest generates a reset token and sends it via email.
|
||||
func (s *Server) handleResetRequest(w http.ResponseWriter, ctx context.Context, email string) {
|
||||
// Look up user
|
||||
var uid int64
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT id FROM res_users WHERE login = $1 AND active = true`, email).Scan(&uid)
|
||||
if err != nil {
|
||||
// Don't reveal whether the email exists — always return success
|
||||
writePortalJSON(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "If an account exists with this email, a reset link has been sent",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Generate token
|
||||
tokenBytes := make([]byte, 32)
|
||||
rand.Read(tokenBytes)
|
||||
token := hex.EncodeToString(tokenBytes)
|
||||
expiration := time.Now().Add(24 * time.Hour)
|
||||
|
||||
// Store token
|
||||
_, err = s.pool.Exec(ctx,
|
||||
`UPDATE res_users SET signup_token = $1, signup_expiration = $2 WHERE id = $3`,
|
||||
token, expiration, uid)
|
||||
if err != nil {
|
||||
log.Printf("portal reset: store token error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Internal error")
|
||||
return
|
||||
}
|
||||
|
||||
// Send email with reset link
|
||||
smtpCfg := tools.LoadSMTPConfig()
|
||||
resetURL := fmt.Sprintf("/web/portal/reset_password?token=%s", token)
|
||||
emailBody := fmt.Sprintf(`<html><body>
|
||||
<p>A password reset was requested for your account.</p>
|
||||
<p>Click the link below to set a new password:</p>
|
||||
<p><a href="%s">Reset Password</a></p>
|
||||
<p>This link expires in 24 hours.</p>
|
||||
<p>If you did not request this, you can ignore this email.</p>
|
||||
</body></html>`, resetURL)
|
||||
|
||||
if err := tools.SendEmail(smtpCfg, email, "Password Reset", emailBody); err != nil {
|
||||
log.Printf("portal reset: send email error: %v", err)
|
||||
// Don't expose email sending errors to the user
|
||||
}
|
||||
|
||||
writePortalJSON(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "If an account exists with this email, a reset link has been sent",
|
||||
})
|
||||
}
|
||||
|
||||
// handleResetWithToken validates the token and sets the new password.
|
||||
func (s *Server) handleResetWithToken(w http.ResponseWriter, ctx context.Context, token, password string) {
|
||||
if len(password) < 8 {
|
||||
writePortalError(w, http.StatusBadRequest, "Password must be at least 8 characters")
|
||||
return
|
||||
}
|
||||
|
||||
// Look up user by token
|
||||
var uid int64
|
||||
var expiration time.Time
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT id, signup_expiration FROM res_users
|
||||
WHERE signup_token = $1 AND active = true`, token).Scan(&uid, &expiration)
|
||||
if err != nil {
|
||||
writePortalError(w, http.StatusBadRequest, "Invalid or expired reset token")
|
||||
return
|
||||
}
|
||||
|
||||
// Check expiration
|
||||
if time.Now().After(expiration) {
|
||||
// Clear expired token
|
||||
s.pool.Exec(ctx,
|
||||
`UPDATE res_users SET signup_token = NULL, signup_expiration = NULL WHERE id = $1`, uid)
|
||||
writePortalError(w, http.StatusBadRequest, "Reset token has expired")
|
||||
return
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
hashedPw, err := tools.HashPassword(password)
|
||||
if err != nil {
|
||||
log.Printf("portal reset: hash password error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Internal error")
|
||||
return
|
||||
}
|
||||
|
||||
// Update password and clear token
|
||||
_, err = s.pool.Exec(ctx,
|
||||
`UPDATE res_users SET password = $1, signup_token = NULL, signup_expiration = NULL
|
||||
WHERE id = $2`, hashedPw, uid)
|
||||
if err != nil {
|
||||
log.Printf("portal reset: update password error: %v", err)
|
||||
writePortalError(w, http.StatusInternalServerError, "Failed to reset password")
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("portal reset: password reset for uid=%d", uid)
|
||||
|
||||
writePortalJSON(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Password has been reset successfully",
|
||||
})
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
@@ -35,6 +36,8 @@ type Server struct {
|
||||
// all JS files (except module_loader.js) plus the XML template bundle,
|
||||
// served as a single file to avoid hundreds of individual HTTP requests.
|
||||
jsBundle string
|
||||
|
||||
bus *Bus // Message bus for Discuss long-polling
|
||||
}
|
||||
|
||||
// New creates a new server instance.
|
||||
@@ -128,6 +131,17 @@ func (s *Server) registerRoutes() {
|
||||
|
||||
// CSV export
|
||||
s.mux.HandleFunc("/web/export/csv", s.handleExportCSV)
|
||||
s.mux.HandleFunc("/web/export/xlsx", s.handleExportXLSX)
|
||||
|
||||
// Import
|
||||
s.mux.HandleFunc("/web/import/csv", s.handleImportCSV)
|
||||
|
||||
// Post-setup wizard
|
||||
s.mux.HandleFunc("/web/setup/wizard", s.handleSetupWizard)
|
||||
s.mux.HandleFunc("/web/setup/wizard/save", s.handleSetupWizardSave)
|
||||
|
||||
// Bank statement import
|
||||
s.mux.HandleFunc("/web/bank_statement/import", s.handleBankStatementImport)
|
||||
|
||||
// Reports (HTML and PDF report rendering)
|
||||
s.mux.HandleFunc("/report/", s.handleReport)
|
||||
@@ -137,10 +151,16 @@ func (s *Server) registerRoutes() {
|
||||
// Logout & Account
|
||||
s.mux.HandleFunc("/web/session/logout", s.handleLogout)
|
||||
s.mux.HandleFunc("/web/session/account", s.handleSessionAccount)
|
||||
s.mux.HandleFunc("/web/session/switch_company", s.handleSwitchCompany)
|
||||
|
||||
// Health check
|
||||
s.mux.HandleFunc("/health", s.handleHealth)
|
||||
|
||||
// Portal routes (external user access)
|
||||
s.registerPortalRoutes()
|
||||
s.registerPortalSignupRoutes()
|
||||
s.registerBusRoutes()
|
||||
|
||||
// Static files (catch-all for /<addon>/static/...)
|
||||
// NOTE: must be last since it's a broad pattern
|
||||
}
|
||||
@@ -255,13 +275,14 @@ func (s *Server) handleCallKW(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract UID from session, default to 1 (admin) if no session
|
||||
uid := int64(1)
|
||||
companyID := int64(1)
|
||||
if sess := GetSession(r); sess != nil {
|
||||
uid = sess.UID
|
||||
companyID = sess.CompanyID
|
||||
// Extract UID from session — reject if no session (defense in depth)
|
||||
sess := GetSession(r)
|
||||
if sess == nil {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: 100, Message: "Session expired"})
|
||||
return
|
||||
}
|
||||
uid := sess.UID
|
||||
companyID := sess.CompanyID
|
||||
|
||||
// Create environment for this request
|
||||
env, err := orm.NewEnvironment(r.Context(), orm.EnvConfig{
|
||||
@@ -294,6 +315,36 @@ func (s *Server) handleCallKW(w http.ResponseWriter, r *http.Request) {
|
||||
s.writeJSONRPC(w, req.ID, result, nil)
|
||||
}
|
||||
|
||||
// sensitiveFields lists fields that only admin (uid=1) may write to.
|
||||
// Prevents privilege escalation via field manipulation.
|
||||
var sensitiveFields = map[string]map[string]bool{
|
||||
"ir.cron": {"user_id": true, "model_name": true, "method_name": true},
|
||||
"ir.model.access": {"group_id": true, "perm_read": true, "perm_write": true, "perm_create": true, "perm_unlink": true},
|
||||
"ir.rule": {"domain_force": true, "groups": true, "perm_read": true, "perm_write": true, "perm_create": true, "perm_unlink": true},
|
||||
"res.users": {"groups_id": true},
|
||||
"res.groups": {"users": true},
|
||||
}
|
||||
|
||||
// checkSensitiveFields blocks non-admin users from writing protected fields.
|
||||
func checkSensitiveFields(env *orm.Environment, model string, vals orm.Values) *RPCError {
|
||||
if env.UID() == 1 || env.IsSuperuser() {
|
||||
return nil
|
||||
}
|
||||
fields, ok := sensitiveFields[model]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
for field := range vals {
|
||||
if fields[field] {
|
||||
return &RPCError{
|
||||
Code: 403,
|
||||
Message: fmt.Sprintf("Access Denied: field %q on %s is admin-only", field, model),
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAccess verifies the current user has permission for the operation.
|
||||
// Mirrors: odoo/addons/base/models/ir_model.py IrModelAccess.check()
|
||||
func (s *Server) checkAccess(env *orm.Environment, model, method string) *RPCError {
|
||||
@@ -317,8 +368,22 @@ func (s *Server) checkAccess(env *orm.Environment, model, method string) *RPCErr
|
||||
`SELECT COUNT(*) FROM ir_model_access a
|
||||
JOIN ir_model m ON m.id = a.model_id
|
||||
WHERE m.model = $1`, model).Scan(&count)
|
||||
if err != nil || count == 0 {
|
||||
return nil // No ACLs defined → open access (like Odoo superuser mode)
|
||||
if err != nil {
|
||||
// DB error → deny access (fail-closed)
|
||||
log.Printf("access: DB error checking ACL for model %s: %v", model, err)
|
||||
return &RPCError{
|
||||
Code: 403,
|
||||
Message: fmt.Sprintf("Access Denied: %s on %s (internal error)", method, model),
|
||||
}
|
||||
}
|
||||
if count == 0 {
|
||||
// No ACL rules defined for this model → deny (fail-closed).
|
||||
// All models should have ACL seed data via seedACLRules().
|
||||
log.Printf("access: no ACL for model %s, denying (fail-closed)", model)
|
||||
return &RPCError{
|
||||
Code: 403,
|
||||
Message: fmt.Sprintf("Access Denied: no ACL rules for %s", model),
|
||||
}
|
||||
}
|
||||
|
||||
// Check if user's groups grant permission
|
||||
@@ -334,7 +399,11 @@ func (s *Server) checkAccess(env *orm.Environment, model, method string) *RPCErr
|
||||
AND (a.group_id IS NULL OR gu.res_users_id = $2)
|
||||
)`, perm), model, env.UID()).Scan(&granted)
|
||||
if err != nil {
|
||||
return nil // On error, allow (fail-open for now)
|
||||
log.Printf("access: DB error checking ACL grant for model %s: %v", model, err)
|
||||
return &RPCError{
|
||||
Code: 403,
|
||||
Message: fmt.Sprintf("Access Denied: %s on %s (internal error)", method, model),
|
||||
}
|
||||
}
|
||||
if !granted {
|
||||
return &RPCError{
|
||||
@@ -379,10 +448,57 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
|
||||
|
||||
switch params.Method {
|
||||
case "has_group":
|
||||
// Always return true for admin user, stub for now
|
||||
return true, nil
|
||||
// Check if current user belongs to the given group.
|
||||
// Mirrors: odoo/orm/models.py BaseModel.user_has_groups()
|
||||
groupXMLID := ""
|
||||
if len(params.Args) > 0 {
|
||||
groupXMLID, _ = params.Args[0].(string)
|
||||
}
|
||||
if groupXMLID == "" {
|
||||
return false, nil
|
||||
}
|
||||
// Admin always has all groups
|
||||
if env.UID() == 1 {
|
||||
return true, nil
|
||||
}
|
||||
// Parse "module.xml_id" format
|
||||
parts := strings.SplitN(groupXMLID, ".", 2)
|
||||
if len(parts) != 2 {
|
||||
return false, nil
|
||||
}
|
||||
// Query: does user belong to this group?
|
||||
var exists bool
|
||||
err := env.Tx().QueryRow(env.Ctx(),
|
||||
`SELECT EXISTS(
|
||||
SELECT 1 FROM res_groups_res_users_rel gur
|
||||
JOIN ir_model_data imd ON imd.res_id = gur.res_groups_id AND imd.model = 'res.groups'
|
||||
WHERE gur.res_users_id = $1 AND imd.module = $2 AND imd.name = $3
|
||||
)`, env.UID(), parts[0], parts[1]).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, nil
|
||||
}
|
||||
return exists, nil
|
||||
|
||||
case "check_access_rights":
|
||||
// Check if current user has the given access right on this model.
|
||||
// Mirrors: odoo/orm/models.py BaseModel.check_access_rights()
|
||||
operation := "read"
|
||||
if len(params.Args) > 0 {
|
||||
if op, ok := params.Args[0].(string); ok {
|
||||
operation = op
|
||||
}
|
||||
}
|
||||
raiseException := true
|
||||
if v, ok := params.KW["raise_exception"].(bool); ok {
|
||||
raiseException = v
|
||||
}
|
||||
accessErr := s.checkAccess(env, params.Model, operation)
|
||||
if accessErr != nil {
|
||||
if raiseException {
|
||||
return nil, accessErr
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
|
||||
case "fields_get":
|
||||
@@ -404,6 +520,11 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
|
||||
vals := parseValuesAt(params.Args, 1)
|
||||
spec, _ := params.KW["specification"].(map[string]interface{})
|
||||
|
||||
// Field-level access control
|
||||
if err := checkSensitiveFields(env, params.Model, vals); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(ids) > 0 && ids[0] > 0 {
|
||||
// Update existing record(s)
|
||||
err := rs.Browse(ids...).Write(vals)
|
||||
@@ -513,6 +634,9 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
|
||||
|
||||
case "create":
|
||||
vals := parseValues(params.Args)
|
||||
if err := checkSensitiveFields(env, params.Model, vals); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
record, err := rs.Create(vals)
|
||||
if err != nil {
|
||||
return nil, &RPCError{Code: -32000, Message: err.Error()}
|
||||
@@ -522,6 +646,9 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
|
||||
case "write":
|
||||
ids := parseIDs(params.Args)
|
||||
vals := parseValuesAt(params.Args, 1)
|
||||
if err := checkSensitiveFields(env, params.Model, vals); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err := rs.Browse(ids...).Write(vals)
|
||||
if err != nil {
|
||||
return nil, &RPCError{Code: -32000, Message: err.Error()}
|
||||
@@ -645,9 +772,33 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
|
||||
}, nil
|
||||
|
||||
case "get_formview_id":
|
||||
return false, nil
|
||||
// Return the default form view ID for this model.
|
||||
// Mirrors: odoo/orm/models.py BaseModel.get_formview_id()
|
||||
var viewID *int64
|
||||
err := env.Tx().QueryRow(env.Ctx(),
|
||||
`SELECT id FROM ir_ui_view
|
||||
WHERE model = $1 AND type = 'form' AND active = true
|
||||
ORDER BY priority, id LIMIT 1`,
|
||||
params.Model).Scan(&viewID)
|
||||
if err != nil || viewID == nil {
|
||||
return false, nil
|
||||
}
|
||||
return *viewID, nil
|
||||
|
||||
case "action_get":
|
||||
// Try registered method first (e.g. res.users has its own action_get).
|
||||
// Mirrors: odoo/addons/base/models/res_users.py action_get()
|
||||
model := orm.Registry.Get(params.Model)
|
||||
if model != nil && model.Methods != nil {
|
||||
if method, ok := model.Methods["action_get"]; ok {
|
||||
ids := parseIDs(params.Args)
|
||||
result, err := method(rs.Browse(ids...), params.Args[1:]...)
|
||||
if err != nil {
|
||||
return nil, &RPCError{Code: -32000, Message: err.Error()}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
|
||||
case "name_create":
|
||||
@@ -665,10 +816,48 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
|
||||
return []interface{}{created.ID(), nameStr}, nil
|
||||
|
||||
case "read_progress_bar":
|
||||
return map[string]interface{}{}, nil
|
||||
return s.handleReadProgressBar(rs, params)
|
||||
|
||||
case "activity_format":
|
||||
return []interface{}{}, nil
|
||||
ids := parseIDs(params.Args)
|
||||
if len(ids) == 0 {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
// Search activities for this model/record
|
||||
actRS := env.Model("mail.activity")
|
||||
var allActivities []orm.Values
|
||||
for _, id := range ids {
|
||||
domain := orm.And(
|
||||
orm.Leaf("res_model", "=", params.Model),
|
||||
orm.Leaf("res_id", "=", id),
|
||||
orm.Leaf("done", "=", false),
|
||||
)
|
||||
found, err := actRS.Search(domain, orm.SearchOpts{Order: "date_deadline"})
|
||||
if err != nil || found.IsEmpty() {
|
||||
continue
|
||||
}
|
||||
records, err := found.Read([]string{"id", "res_model", "res_id", "activity_type_id", "summary", "note", "date_deadline", "user_id", "state"})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
allActivities = append(allActivities, records...)
|
||||
}
|
||||
if allActivities == nil {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
// Format M2O fields
|
||||
actSpec := map[string]interface{}{
|
||||
"activity_type_id": map[string]interface{}{},
|
||||
"user_id": map[string]interface{}{},
|
||||
}
|
||||
formatM2OFields(env, "mail.activity", allActivities, actSpec)
|
||||
formatDateFields("mail.activity", allActivities)
|
||||
normalizeNullFields("mail.activity", allActivities)
|
||||
actResult := make([]interface{}, len(allActivities))
|
||||
for i, a := range allActivities {
|
||||
actResult[i] = a
|
||||
}
|
||||
return actResult, nil
|
||||
|
||||
case "action_archive":
|
||||
ids := parseIDs(params.Args)
|
||||
@@ -697,6 +886,199 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
|
||||
}
|
||||
return created.ID(), nil
|
||||
|
||||
case "web_resequence":
|
||||
// Resequence records by their IDs (drag&drop reordering).
|
||||
// Mirrors: odoo/addons/web/models/models.py web_resequence()
|
||||
ids := parseIDs(params.Args)
|
||||
if len(ids) == 0 {
|
||||
return []orm.Values{}, nil
|
||||
}
|
||||
|
||||
// Parse field_name (default "sequence")
|
||||
fieldName := "sequence"
|
||||
if v, ok := params.KW["field_name"].(string); ok {
|
||||
fieldName = v
|
||||
}
|
||||
|
||||
// Parse offset (default 0)
|
||||
offset := 0
|
||||
if v, ok := params.KW["offset"].(float64); ok {
|
||||
offset = int(v)
|
||||
}
|
||||
|
||||
// Check if field exists on the model
|
||||
model := orm.Registry.Get(params.Model)
|
||||
if model == nil || model.GetField(fieldName) == nil {
|
||||
return []orm.Values{}, nil
|
||||
}
|
||||
|
||||
// Update sequence for each record in order
|
||||
for i, id := range ids {
|
||||
if err := rs.Browse(id).Write(orm.Values{fieldName: offset + i}); err != nil {
|
||||
return nil, &RPCError{Code: -32000, Message: err.Error()}
|
||||
}
|
||||
}
|
||||
|
||||
// Return records via web_read
|
||||
spec, _ := params.KW["specification"].(map[string]interface{})
|
||||
readParams := CallKWParams{
|
||||
Model: params.Model,
|
||||
Method: "web_read",
|
||||
Args: []interface{}{ids},
|
||||
KW: map[string]interface{}{"specification": spec},
|
||||
}
|
||||
return handleWebRead(env, params.Model, readParams)
|
||||
|
||||
case "message_post":
|
||||
// Post a message on the record's chatter.
|
||||
// Mirrors: odoo/addons/mail/models/mail_thread.py message_post()
|
||||
ids := parseIDs(params.Args)
|
||||
if len(ids) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
body, _ := params.KW["body"].(string)
|
||||
messageType := "comment"
|
||||
if v, _ := params.KW["message_type"].(string); v != "" {
|
||||
messageType = v
|
||||
}
|
||||
|
||||
// Get author from current user's partner_id
|
||||
var authorID int64
|
||||
if err := env.Tx().QueryRow(env.Ctx(),
|
||||
`SELECT partner_id FROM res_users WHERE id = $1`, env.UID(),
|
||||
).Scan(&authorID); err != nil {
|
||||
log.Printf("warning: message_post author lookup failed: %v", err)
|
||||
}
|
||||
|
||||
// Create mail.message linked to the current model/record
|
||||
var msgID int64
|
||||
err := env.Tx().QueryRow(env.Ctx(),
|
||||
`INSERT INTO mail_message (model, res_id, body, message_type, author_id, date, create_uid, write_uid, create_date, write_date)
|
||||
VALUES ($1, $2, $3, $4, $5, NOW(), $6, $6, NOW(), NOW())
|
||||
RETURNING id`,
|
||||
params.Model, ids[0], body, messageType, authorID, env.UID(),
|
||||
).Scan(&msgID)
|
||||
if err != nil {
|
||||
return nil, &RPCError{Code: -32000, Message: err.Error()}
|
||||
}
|
||||
return msgID, nil
|
||||
|
||||
case "_message_get_thread":
|
||||
// Get messages for a record's chatter.
|
||||
// Mirrors: odoo/addons/mail/models/mail_thread.py
|
||||
ids := parseIDs(params.Args)
|
||||
if len(ids) == 0 {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
|
||||
rows, err := env.Tx().Query(env.Ctx(),
|
||||
`SELECT m.id, m.body, m.message_type, m.date,
|
||||
m.author_id, COALESCE(p.name, ''),
|
||||
COALESCE(m.subject, ''), COALESCE(m.email_from, '')
|
||||
FROM mail_message m
|
||||
LEFT JOIN res_partner p ON p.id = m.author_id
|
||||
WHERE m.model = $1 AND m.res_id = $2
|
||||
ORDER BY m.id DESC`,
|
||||
params.Model, ids[0],
|
||||
)
|
||||
if err != nil {
|
||||
return nil, &RPCError{Code: -32000, Message: err.Error()}
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var body, msgType, subject, emailFrom string
|
||||
var date interface{}
|
||||
var authorID int64
|
||||
var authorName string
|
||||
|
||||
if scanErr := rows.Scan(&id, &body, &msgType, &date, &authorID, &authorName, &subject, &emailFrom); scanErr != nil {
|
||||
continue
|
||||
}
|
||||
msg := map[string]interface{}{
|
||||
"id": id,
|
||||
"body": body,
|
||||
"message_type": msgType,
|
||||
"date": date,
|
||||
"subject": subject,
|
||||
"email_from": emailFrom,
|
||||
}
|
||||
if authorID > 0 {
|
||||
msg["author_id"] = []interface{}{authorID, authorName}
|
||||
} else {
|
||||
msg["author_id"] = false
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
if messages == nil {
|
||||
messages = []map[string]interface{}{}
|
||||
}
|
||||
return messages, nil
|
||||
|
||||
case "read_followers":
|
||||
ids := parseIDs(params.Args)
|
||||
if len(ids) == 0 {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
// Search followers for this model/record
|
||||
followerRS := env.Model("mail.followers")
|
||||
domain := orm.And(
|
||||
orm.Leaf("res_model", "=", params.Model),
|
||||
orm.Leaf("res_id", "in", ids),
|
||||
)
|
||||
found, err := followerRS.Search(domain, orm.SearchOpts{Limit: 100})
|
||||
if err != nil || found.IsEmpty() {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
followerRecords, err := found.Read([]string{"id", "res_model", "res_id", "partner_id"})
|
||||
if err != nil {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
followerSpec := map[string]interface{}{"partner_id": map[string]interface{}{}}
|
||||
formatM2OFields(env, "mail.followers", followerRecords, followerSpec)
|
||||
normalizeNullFields("mail.followers", followerRecords)
|
||||
followerResult := make([]interface{}, len(followerRecords))
|
||||
for i, r := range followerRecords {
|
||||
followerResult[i] = r
|
||||
}
|
||||
return followerResult, nil
|
||||
|
||||
case "get_activity_data":
|
||||
// Return activity summary data for records.
|
||||
// Mirrors: odoo/addons/mail/models/mail_activity_mixin.py
|
||||
emptyResult := map[string]interface{}{
|
||||
"activity_types": []interface{}{},
|
||||
"activity_res_ids": map[string]interface{}{},
|
||||
"grouped_activities": map[string]interface{}{},
|
||||
}
|
||||
|
||||
ids := parseIDs(params.Args)
|
||||
if len(ids) == 0 {
|
||||
return emptyResult, nil
|
||||
}
|
||||
|
||||
// Get activity types
|
||||
typeRS := env.Model("mail.activity.type")
|
||||
types, err := typeRS.Search(nil, orm.SearchOpts{Order: "sequence, id"})
|
||||
if err != nil || types.IsEmpty() {
|
||||
return emptyResult, nil
|
||||
}
|
||||
typeRecords, _ := types.Read([]string{"id", "name"})
|
||||
|
||||
typeList := make([]interface{}, len(typeRecords))
|
||||
for i, t := range typeRecords {
|
||||
typeList[i] = t
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"activity_types": typeList,
|
||||
"activity_res_ids": map[string]interface{}{},
|
||||
"grouped_activities": map[string]interface{}{},
|
||||
}, nil
|
||||
|
||||
default:
|
||||
// Try registered business methods on the model.
|
||||
// Mirrors: odoo/service/model.py call_kw() + odoo/addons/web/controllers/dataset.py call_button()
|
||||
@@ -732,6 +1114,58 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
|
||||
|
||||
// --- Session / Auth Endpoints ---
|
||||
|
||||
// loginAttemptInfo tracks login attempts for rate limiting.
|
||||
type loginAttemptInfo struct {
|
||||
Count int
|
||||
LastTime time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
loginAttempts = make(map[string]loginAttemptInfo)
|
||||
loginAttemptsMu sync.Mutex
|
||||
)
|
||||
|
||||
// checkLoginRateLimit returns false if the login is rate-limited (too many attempts).
|
||||
func (s *Server) checkLoginRateLimit(login string) bool {
|
||||
loginAttemptsMu.Lock()
|
||||
defer loginAttemptsMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Periodic cleanup: evict stale entries (>15 min old) to prevent unbounded growth
|
||||
if len(loginAttempts) > 100 {
|
||||
for k, v := range loginAttempts {
|
||||
if now.Sub(v.LastTime) > 15*time.Minute {
|
||||
delete(loginAttempts, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info := loginAttempts[login]
|
||||
|
||||
// Reset after 15 minutes
|
||||
if now.Sub(info.LastTime) > 15*time.Minute {
|
||||
info = loginAttemptInfo{}
|
||||
}
|
||||
|
||||
// Max 10 attempts per 15 minutes
|
||||
if info.Count >= 10 {
|
||||
return false // Rate limited
|
||||
}
|
||||
|
||||
info.Count++
|
||||
info.LastTime = now
|
||||
loginAttempts[login] = info
|
||||
return true
|
||||
}
|
||||
|
||||
// resetLoginRateLimit clears the rate limit counter on successful login.
|
||||
func (s *Server) resetLoginRateLimit(login string) {
|
||||
loginAttemptsMu.Lock()
|
||||
defer loginAttemptsMu.Unlock()
|
||||
delete(loginAttempts, login)
|
||||
}
|
||||
|
||||
func (s *Server) handleAuthenticate(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
@@ -754,6 +1188,14 @@ func (s *Server) handleAuthenticate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Rate limit login attempts
|
||||
if !s.checkLoginRateLimit(params.Login) {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{
|
||||
Code: 429, Message: "Too many login attempts. Please try again later.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Query user by login
|
||||
var uid int64
|
||||
var companyID int64
|
||||
@@ -776,16 +1218,40 @@ func (s *Server) handleAuthenticate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check password (support both bcrypt and plaintext for migration)
|
||||
if !tools.CheckPassword(hashedPw, params.Password) && hashedPw != params.Password {
|
||||
// Check password (bcrypt only — no plaintext fallback)
|
||||
if !tools.CheckPassword(hashedPw, params.Password) {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{
|
||||
Code: 100, Message: "Access Denied: invalid login or password",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Successful login – reset rate limiter
|
||||
s.resetLoginRateLimit(params.Login)
|
||||
|
||||
// Query allowed companies for the user
|
||||
allowedCompanyIDs := []int64{companyID}
|
||||
rows, err := s.pool.Query(r.Context(),
|
||||
`SELECT DISTINCT c.id FROM res_company c
|
||||
WHERE c.active = true
|
||||
ORDER BY c.id`)
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
var ids []int64
|
||||
for rows.Next() {
|
||||
var cid int64
|
||||
if rows.Scan(&cid) == nil {
|
||||
ids = append(ids, cid)
|
||||
}
|
||||
}
|
||||
if len(ids) > 0 {
|
||||
allowedCompanyIDs = ids
|
||||
}
|
||||
}
|
||||
|
||||
// Create session
|
||||
sess := s.sessions.New(uid, companyID, params.Login)
|
||||
sess.AllowedCompanyIDs = allowedCompanyIDs
|
||||
|
||||
// Set session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
@@ -793,6 +1259,7 @@ func (s *Server) handleAuthenticate(w http.ResponseWriter, r *http.Request) {
|
||||
Value: sess.ID,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
@@ -857,6 +1324,7 @@ func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
})
|
||||
http.Redirect(w, r, "/web/login", http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -13,12 +13,14 @@ import (
|
||||
|
||||
// Session represents an authenticated user session.
|
||||
type Session struct {
|
||||
ID string
|
||||
UID int64
|
||||
CompanyID int64
|
||||
Login string
|
||||
CreatedAt time.Time
|
||||
LastActivity time.Time
|
||||
ID string
|
||||
UID int64
|
||||
CompanyID int64
|
||||
AllowedCompanyIDs []int64
|
||||
Login string
|
||||
CSRFToken string
|
||||
CreatedAt time.Time
|
||||
LastActivity time.Time
|
||||
}
|
||||
|
||||
// SessionStore is a session store with an in-memory cache backed by PostgreSQL.
|
||||
@@ -47,10 +49,15 @@ func InitSessionTable(ctx context.Context, pool *pgxpool.Pool) error {
|
||||
uid INT8 NOT NULL,
|
||||
company_id INT8 NOT NULL,
|
||||
login VARCHAR(255),
|
||||
csrf_token VARCHAR(64) DEFAULT '',
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
last_seen TIMESTAMP DEFAULT NOW()
|
||||
)
|
||||
`)
|
||||
if err == nil {
|
||||
// Add csrf_token column if table already exists without it
|
||||
pool.Exec(ctx, `ALTER TABLE sessions ADD COLUMN IF NOT EXISTS csrf_token VARCHAR(64) DEFAULT ''`)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -67,6 +74,7 @@ func (s *SessionStore) New(uid, companyID int64, login string) *Session {
|
||||
UID: uid,
|
||||
CompanyID: companyID,
|
||||
Login: login,
|
||||
CSRFToken: generateToken(),
|
||||
CreatedAt: now,
|
||||
LastActivity: now,
|
||||
}
|
||||
@@ -81,10 +89,10 @@ func (s *SessionStore) New(uid, companyID int64, login string) *Session {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_, err := s.pool.Exec(ctx,
|
||||
`INSERT INTO sessions (id, uid, company_id, login, created_at, last_seen)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
`INSERT INTO sessions (id, uid, company_id, login, csrf_token, created_at, last_seen)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (id) DO NOTHING`,
|
||||
token, uid, companyID, login, now, now)
|
||||
token, uid, companyID, login, sess.CSRFToken, now, now)
|
||||
if err != nil {
|
||||
log.Printf("session: failed to persist session to DB: %v", err)
|
||||
}
|
||||
@@ -106,20 +114,23 @@ func (s *SessionStore) Get(id string) *Session {
|
||||
s.Delete(id)
|
||||
return nil
|
||||
}
|
||||
// Update last activity
|
||||
|
||||
now := time.Now()
|
||||
needsDBUpdate := time.Since(sess.LastActivity) > 30*time.Second
|
||||
|
||||
// Update last activity in memory
|
||||
s.mu.Lock()
|
||||
sess.LastActivity = now
|
||||
s.mu.Unlock()
|
||||
|
||||
// Update last_seen in DB asynchronously
|
||||
if s.pool != nil {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.pool.Exec(ctx,
|
||||
`UPDATE sessions SET last_seen = $1 WHERE id = $2`, now, id)
|
||||
}()
|
||||
// Throttle DB writes: only persist every 30s to avoid per-request overhead
|
||||
if needsDBUpdate && s.pool != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := s.pool.Exec(ctx,
|
||||
`UPDATE sessions SET last_seen = $1 WHERE id = $2`, now, id); err != nil {
|
||||
log.Printf("session: failed to update last_seen in DB: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return sess
|
||||
@@ -134,14 +145,20 @@ func (s *SessionStore) Get(id string) *Session {
|
||||
defer cancel()
|
||||
|
||||
sess = &Session{}
|
||||
var csrfToken string
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT id, uid, company_id, login, created_at, last_seen
|
||||
`SELECT id, uid, company_id, login, COALESCE(csrf_token, ''), created_at, last_seen
|
||||
FROM sessions WHERE id = $1`, id).Scan(
|
||||
&sess.ID, &sess.UID, &sess.CompanyID, &sess.Login,
|
||||
&sess.ID, &sess.UID, &sess.CompanyID, &sess.Login, &csrfToken,
|
||||
&sess.CreatedAt, &sess.LastActivity)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if csrfToken != "" {
|
||||
sess.CSRFToken = csrfToken
|
||||
} else {
|
||||
sess.CSRFToken = generateToken()
|
||||
}
|
||||
|
||||
// Check TTL
|
||||
if time.Since(sess.LastActivity) > s.ttl {
|
||||
@@ -149,18 +166,18 @@ func (s *SessionStore) Get(id string) *Session {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update last activity
|
||||
// Update last activity and add to memory cache
|
||||
now := time.Now()
|
||||
sess.LastActivity = now
|
||||
|
||||
// Add to memory cache
|
||||
s.mu.Lock()
|
||||
sess.LastActivity = now
|
||||
s.sessions[id] = sess
|
||||
s.mu.Unlock()
|
||||
|
||||
// Update last_seen in DB
|
||||
s.pool.Exec(ctx,
|
||||
`UPDATE sessions SET last_seen = $1 WHERE id = $2`, now, id)
|
||||
if _, err := s.pool.Exec(ctx,
|
||||
`UPDATE sessions SET last_seen = $1 WHERE id = $2`, now, id); err != nil {
|
||||
log.Printf("session: failed to update last_seen in DB: %v", err)
|
||||
}
|
||||
|
||||
return sess
|
||||
}
|
||||
|
||||
@@ -6,14 +6,17 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"odoo-go/pkg/service"
|
||||
"odoo-go/pkg/tools"
|
||||
)
|
||||
|
||||
|
||||
var dbnamePattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]+$`)
|
||||
|
||||
// isSetupNeeded checks if the current database has been initialized.
|
||||
@@ -55,6 +58,16 @@ func (s *Server) handleDatabaseCreate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate master password (default: "admin", configurable via ODOO_MASTER_PASSWORD env)
|
||||
masterPw := os.Getenv("ODOO_MASTER_PASSWORD")
|
||||
if masterPw == "" {
|
||||
masterPw = "admin"
|
||||
}
|
||||
if params.MasterPwd != masterPw {
|
||||
writeJSON(w, map[string]string{"error": "Invalid master password"})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate
|
||||
if params.Login == "" || params.Password == "" {
|
||||
writeJSON(w, map[string]string{"error": "Email and password are required"})
|
||||
@@ -111,7 +124,10 @@ func (s *Server) handleDatabaseCreate(w http.ResponseWriter, r *http.Request) {
|
||||
domain := parts[1]
|
||||
domainParts := strings.Split(domain, ".")
|
||||
if len(domainParts) > 0 {
|
||||
companyName = strings.Title(domainParts[0])
|
||||
name := domainParts[0]
|
||||
if len(name) > 0 {
|
||||
companyName = strings.ToUpper(name[:1]) + name[1:]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -175,6 +191,195 @@ func writeJSON(w http.ResponseWriter, v interface{}) {
|
||||
json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
|
||||
// postSetupDone caches the result of isPostSetupNeeded to avoid a DB query on every request.
|
||||
var postSetupDone atomic.Bool
|
||||
|
||||
// isPostSetupNeeded checks if the company still has default values (needs configuration).
|
||||
func (s *Server) isPostSetupNeeded() bool {
|
||||
if postSetupDone.Load() {
|
||||
return false
|
||||
}
|
||||
var name string
|
||||
err := s.pool.QueryRow(context.Background(),
|
||||
`SELECT COALESCE(name, '') FROM res_company WHERE id = 1`).Scan(&name)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
needed := name == "" || name == "My Company" || strings.HasPrefix(name, "My ")
|
||||
if !needed {
|
||||
postSetupDone.Store(true)
|
||||
}
|
||||
return needed
|
||||
}
|
||||
|
||||
// handleSetupWizard serves the post-setup configuration wizard.
|
||||
// Shown after first login when the company has not been configured yet.
|
||||
// Mirrors: odoo/addons/base_setup/views/res_config_settings_views.xml
|
||||
func (s *Server) handleSetupWizard(w http.ResponseWriter, r *http.Request) {
|
||||
sess := GetSession(r)
|
||||
if sess == nil {
|
||||
http.Redirect(w, r, "/web/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Load current company data
|
||||
var companyName, street, city, zip, phone, email, website, vat string
|
||||
var countryID int64
|
||||
s.pool.QueryRow(context.Background(),
|
||||
`SELECT COALESCE(name,''), COALESCE(street,''), COALESCE(city,''), COALESCE(zip,''),
|
||||
COALESCE(phone,''), COALESCE(email,''), COALESCE(website,''), COALESCE(vat,''),
|
||||
COALESCE(country_id, 0)
|
||||
FROM res_company WHERE id = $1`, sess.CompanyID,
|
||||
).Scan(&companyName, &street, &city, &zip, &phone, &email, &website, &vat, &countryID)
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
esc := htmlEscape
|
||||
fmt.Fprintf(w, setupWizardHTML,
|
||||
esc(companyName), esc(street), esc(city), esc(zip), esc(phone), esc(email), esc(website), esc(vat))
|
||||
}
|
||||
|
||||
// handleSetupWizardSave saves the post-setup wizard data.
|
||||
func (s *Server) handleSetupWizardSave(w http.ResponseWriter, r *http.Request) {
|
||||
sess := GetSession(r)
|
||||
if sess == nil {
|
||||
writeJSON(w, map[string]string{"error": "Not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var params struct {
|
||||
CompanyName string `json:"company_name"`
|
||||
Street string `json:"street"`
|
||||
City string `json:"city"`
|
||||
Zip string `json:"zip"`
|
||||
Phone string `json:"phone"`
|
||||
Email string `json:"email"`
|
||||
Website string `json:"website"`
|
||||
Vat string `json:"vat"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
|
||||
writeJSON(w, map[string]string{"error": "Invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
if params.CompanyName == "" {
|
||||
writeJSON(w, map[string]string{"error": "Company name is required"})
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.pool.Exec(context.Background(),
|
||||
`UPDATE res_company SET name=$1, street=$2, city=$3, zip=$4, phone=$5, email=$6, website=$7, vat=$8
|
||||
WHERE id = $9`,
|
||||
params.CompanyName, params.Street, params.City, params.Zip,
|
||||
params.Phone, params.Email, params.Website, params.Vat, sess.CompanyID)
|
||||
if err != nil {
|
||||
writeJSON(w, map[string]string{"error": fmt.Sprintf("Save error: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// Also update the partner linked to the company
|
||||
s.pool.Exec(context.Background(),
|
||||
`UPDATE res_partner SET name=$1, street=$2, city=$3, zip=$4, phone=$5, email=$6, website=$7, vat=$8
|
||||
WHERE id = (SELECT partner_id FROM res_company WHERE id = $9)`,
|
||||
params.CompanyName, params.Street, params.City, params.Zip,
|
||||
params.Phone, params.Email, params.Website, params.Vat, sess.CompanyID)
|
||||
|
||||
postSetupDone.Store(true) // Mark setup as done so we don't redirect again
|
||||
writeJSON(w, map[string]interface{}{"status": "ok", "redirect": "/odoo"})
|
||||
}
|
||||
|
||||
var setupWizardHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<title>Setup — Configure Your Company</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
background: #f0eeee; display: flex; align-items: center; justify-content: center; min-height: 100vh; }
|
||||
.wizard { background: white; padding: 40px; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||
width: 100%%; max-width: 560px; }
|
||||
.wizard h1 { color: #71639e; margin-bottom: 6px; font-size: 24px; }
|
||||
.wizard .subtitle { color: #666; margin-bottom: 24px; font-size: 14px; }
|
||||
.wizard label { display: block; margin-bottom: 4px; font-weight: 500; color: #555; font-size: 13px; }
|
||||
.wizard input { width: 100%%; padding: 9px 12px; border: 1px solid #ddd; border-radius: 4px;
|
||||
font-size: 14px; margin-bottom: 14px; }
|
||||
.wizard input:focus { outline: none; border-color: #71639e; box-shadow: 0 0 0 2px rgba(113,99,158,0.2); }
|
||||
.wizard button { width: 100%%; padding: 14px; background: #71639e; color: white; border: none;
|
||||
border-radius: 4px; font-size: 16px; cursor: pointer; margin-top: 10px; }
|
||||
.wizard button:hover { background: #5f5387; }
|
||||
.wizard .skip { text-align: center; margin-top: 12px; }
|
||||
.wizard .skip a { color: #999; text-decoration: none; font-size: 13px; }
|
||||
.wizard .skip a:hover { color: #666; }
|
||||
.row { display: flex; gap: 12px; }
|
||||
.row > div { flex: 1; }
|
||||
.error { color: #dc3545; margin-bottom: 12px; display: none; text-align: center; font-size: 14px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="wizard">
|
||||
<h1>Configure Your Company</h1>
|
||||
<p class="subtitle">Set up your company information</p>
|
||||
<div id="error" class="error"></div>
|
||||
<form id="wizardForm">
|
||||
<label>Company Name *</label>
|
||||
<input type="text" id="company_name" value="%s" required/>
|
||||
|
||||
<label>Street</label>
|
||||
<input type="text" id="street" value="%s"/>
|
||||
|
||||
<div class="row">
|
||||
<div><label>City</label><input type="text" id="city" value="%s"/></div>
|
||||
<div><label>ZIP</label><input type="text" id="zip" value="%s"/></div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div><label>Phone</label><input type="tel" id="phone" value="%s"/></div>
|
||||
<div><label>Email</label><input type="email" id="email" value="%s"/></div>
|
||||
</div>
|
||||
|
||||
<label>Website</label>
|
||||
<input type="url" id="website" value="%s" placeholder="https://"/>
|
||||
|
||||
<label>Tax ID / VAT</label>
|
||||
<input type="text" id="vat" value="%s"/>
|
||||
|
||||
<button type="submit">Save & Continue</button>
|
||||
</form>
|
||||
<div class="skip"><a href="/odoo">Skip for now</a></div>
|
||||
</div>
|
||||
<script>
|
||||
document.getElementById('wizardForm').addEventListener('submit', function(e) {
|
||||
e.preventDefault();
|
||||
fetch('/web/setup/wizard/save', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: JSON.stringify({
|
||||
company_name: document.getElementById('company_name').value,
|
||||
street: document.getElementById('street').value,
|
||||
city: document.getElementById('city').value,
|
||||
zip: document.getElementById('zip').value,
|
||||
phone: document.getElementById('phone').value,
|
||||
email: document.getElementById('email').value,
|
||||
website: document.getElementById('website').value,
|
||||
vat: document.getElementById('vat').value
|
||||
})
|
||||
})
|
||||
.then(function(r) { return r.json(); })
|
||||
.then(function(result) {
|
||||
if (result.error) {
|
||||
var el = document.getElementById('error');
|
||||
el.textContent = result.error;
|
||||
el.style.display = 'block';
|
||||
} else {
|
||||
window.location.href = result.redirect || '/odoo';
|
||||
}
|
||||
});
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
// --- Database Manager HTML ---
|
||||
// Mirrors: odoo/addons/web/static/src/public/database_manager.create_form.qweb.html
|
||||
var databaseManagerHTML = `<!DOCTYPE html>
|
||||
|
||||
@@ -43,8 +43,9 @@ func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) {
|
||||
addonName := parts[0]
|
||||
filePath := parts[2]
|
||||
|
||||
// Security: prevent directory traversal
|
||||
if strings.Contains(filePath, "..") {
|
||||
// Security: prevent directory traversal in both addonName and filePath
|
||||
if strings.Contains(filePath, "..") || strings.Contains(addonName, "..") ||
|
||||
strings.Contains(addonName, "/") || strings.Contains(addonName, "\\") {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -88,7 +88,7 @@ func TestExtractImports(t *testing.T) {
|
||||
content := `import { Foo, Bar } from "@web/core/foo";
|
||||
import { Baz as Qux } from "@web/core/baz";
|
||||
const x = 1;`
|
||||
deps, requires, clean := extractImports(content)
|
||||
deps, requires, clean := extractImports("test.module", content)
|
||||
|
||||
if len(deps) != 2 {
|
||||
t.Fatalf("expected 2 deps, got %d: %v", len(deps), deps)
|
||||
@@ -120,7 +120,7 @@ const x = 1;`
|
||||
|
||||
t.Run("default import", func(t *testing.T) {
|
||||
content := `import Foo from "@web/core/foo";`
|
||||
deps, requires, _ := extractImports(content)
|
||||
deps, requires, _ := extractImports("test.module", content)
|
||||
|
||||
if len(deps) != 1 || deps[0] != "@web/core/foo" {
|
||||
t.Errorf("deps = %v, want [@web/core/foo]", deps)
|
||||
@@ -132,7 +132,7 @@ const x = 1;`
|
||||
|
||||
t.Run("namespace import", func(t *testing.T) {
|
||||
content := `import * as utils from "@web/core/utils";`
|
||||
deps, requires, _ := extractImports(content)
|
||||
deps, requires, _ := extractImports("test.module", content)
|
||||
|
||||
if len(deps) != 1 || deps[0] != "@web/core/utils" {
|
||||
t.Errorf("deps = %v, want [@web/core/utils]", deps)
|
||||
@@ -144,7 +144,7 @@ const x = 1;`
|
||||
|
||||
t.Run("side-effect import", func(t *testing.T) {
|
||||
content := `import "@web/core/setup";`
|
||||
deps, requires, _ := extractImports(content)
|
||||
deps, requires, _ := extractImports("test.module", content)
|
||||
|
||||
if len(deps) != 1 || deps[0] != "@web/core/setup" {
|
||||
t.Errorf("deps = %v, want [@web/core/setup]", deps)
|
||||
@@ -157,7 +157,7 @@ const x = 1;`
|
||||
t.Run("dedup deps", func(t *testing.T) {
|
||||
content := `import { Foo } from "@web/core/foo";
|
||||
import { Bar } from "@web/core/foo";`
|
||||
deps, _, _ := extractImports(content)
|
||||
deps, _, _ := extractImports("test.module", content)
|
||||
|
||||
if len(deps) != 1 {
|
||||
t.Errorf("expected deduped deps, got %v", deps)
|
||||
@@ -167,7 +167,7 @@ import { Bar } from "@web/core/foo";`
|
||||
|
||||
func TestTransformExports(t *testing.T) {
|
||||
t.Run("export class", func(t *testing.T) {
|
||||
got := transformExports("export class Foo extends Bar {")
|
||||
got, _ := transformExports("export class Foo extends Bar {")
|
||||
want := "const Foo = __exports.Foo = class Foo extends Bar {"
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
@@ -175,15 +175,18 @@ func TestTransformExports(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("export function", func(t *testing.T) {
|
||||
got := transformExports("export function doSomething(a, b) {")
|
||||
want := `__exports.doSomething = function doSomething(a, b) {`
|
||||
got, deferred := transformExports("export function doSomething(a, b) {")
|
||||
want := `function doSomething(a, b) {`
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
if len(deferred) != 1 || deferred[0] != "doSomething" {
|
||||
t.Errorf("deferred = %v, want [doSomething]", deferred)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("export const", func(t *testing.T) {
|
||||
got := transformExports("export const MAX_SIZE = 100;")
|
||||
got, _ := transformExports("export const MAX_SIZE = 100;")
|
||||
want := "const MAX_SIZE = __exports.MAX_SIZE = 100;"
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
@@ -191,7 +194,7 @@ func TestTransformExports(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("export let", func(t *testing.T) {
|
||||
got := transformExports("export let counter = 0;")
|
||||
got, _ := transformExports("export let counter = 0;")
|
||||
want := "let counter = __exports.counter = 0;"
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
@@ -199,7 +202,7 @@ func TestTransformExports(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("export default", func(t *testing.T) {
|
||||
got := transformExports("export default Foo;")
|
||||
got, _ := transformExports("export default Foo;")
|
||||
want := `__exports[Symbol.for("default")] = Foo;`
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
@@ -207,7 +210,7 @@ func TestTransformExports(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("export named", func(t *testing.T) {
|
||||
got := transformExports("export { Foo, Bar };")
|
||||
got, _ := transformExports("export { Foo, Bar };")
|
||||
if !strings.Contains(got, "__exports.Foo = Foo;") {
|
||||
t.Errorf("missing Foo export in: %s", got)
|
||||
}
|
||||
@@ -217,7 +220,7 @@ func TestTransformExports(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("export named with alias", func(t *testing.T) {
|
||||
got := transformExports("export { Foo as default };")
|
||||
got, _ := transformExports("export { Foo as default };")
|
||||
if !strings.Contains(got, "__exports.default = Foo;") {
|
||||
t.Errorf("missing aliased export in: %s", got)
|
||||
}
|
||||
|
||||
@@ -20,12 +20,27 @@ func (s *Server) handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse multipart form (max 128MB)
|
||||
if err := r.ParseMultipartForm(128 << 20); err != nil {
|
||||
// Limit upload size to 50MB
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 50<<20)
|
||||
|
||||
// Parse multipart form (max 50MB)
|
||||
if err := r.ParseMultipartForm(50 << 20); err != nil {
|
||||
http.Error(w, "File too large", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
|
||||
// CSRF validation for multipart form uploads.
|
||||
// Mirrors: odoo/http.py validate_csrf()
|
||||
sess := GetSession(r)
|
||||
if sess != nil {
|
||||
csrfToken := r.FormValue("csrf_token")
|
||||
if csrfToken != sess.CSRFToken {
|
||||
log.Printf("upload: CSRF token mismatch for uid=%d", sess.UID)
|
||||
http.Error(w, "CSRF validation failed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("ufile")
|
||||
if err != nil {
|
||||
http.Error(w, "No file uploaded", http.StatusBadRequest)
|
||||
|
||||
@@ -195,6 +195,12 @@ func generateDefaultView(modelName, viewType string) string {
|
||||
return generateDefaultPivotView(m)
|
||||
case "graph":
|
||||
return generateDefaultGraphView(m)
|
||||
case "calendar":
|
||||
return generateDefaultCalendarView(m)
|
||||
case "activity":
|
||||
return generateDefaultActivityView(m)
|
||||
case "dashboard":
|
||||
return generateDefaultDashboardView(m)
|
||||
default:
|
||||
return fmt.Sprintf("<%s><field name=\"id\"/></%s>", viewType, viewType)
|
||||
}
|
||||
@@ -530,6 +536,161 @@ func generateDefaultGraphView(m *orm.Model) string {
|
||||
return fmt.Sprintf("<graph>\n %s\n</graph>", strings.Join(fields, "\n "))
|
||||
}
|
||||
|
||||
// generateDefaultCalendarView creates a calendar view with auto-detected date fields.
|
||||
// The OWL CalendarArchParser requires date_start; date_stop and color are optional.
|
||||
// Mirrors: odoo/addons/web/static/src/views/calendar/calendar_arch_parser.js
|
||||
func generateDefaultCalendarView(m *orm.Model) string {
|
||||
// Auto-detect date_start field (priority order)
|
||||
dateStart := ""
|
||||
for _, candidate := range []string{"start", "date_start", "date_from", "date_order", "date_begin", "date"} {
|
||||
if f := m.GetField(candidate); f != nil && (f.Type == orm.TypeDatetime || f.Type == orm.TypeDate) {
|
||||
dateStart = candidate
|
||||
break
|
||||
}
|
||||
}
|
||||
if dateStart == "" {
|
||||
// Fallback: find any datetime/date field
|
||||
for _, name := range sortedFieldNames(m) {
|
||||
f := m.GetField(name)
|
||||
if f != nil && (f.Type == orm.TypeDatetime || f.Type == orm.TypeDate) && f.Name != "create_date" && f.Name != "write_date" {
|
||||
dateStart = name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if dateStart == "" {
|
||||
// No date field found — return minimal arch that won't crash
|
||||
return `<calendar date_start="create_date"><field name="display_name"/></calendar>`
|
||||
}
|
||||
|
||||
// Auto-detect date_stop field
|
||||
dateStop := ""
|
||||
for _, candidate := range []string{"stop", "date_stop", "date_to", "date_end"} {
|
||||
if f := m.GetField(candidate); f != nil && (f.Type == orm.TypeDatetime || f.Type == orm.TypeDate) {
|
||||
dateStop = candidate
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-detect color field (M2O fields make good color discriminators)
|
||||
colorField := ""
|
||||
for _, candidate := range []string{"color", "user_id", "partner_id", "stage_id"} {
|
||||
if f := m.GetField(candidate); f != nil {
|
||||
colorField = candidate
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-detect all_day field
|
||||
allDay := ""
|
||||
for _, candidate := range []string{"allday", "all_day"} {
|
||||
if f := m.GetField(candidate); f != nil && f.Type == orm.TypeBoolean {
|
||||
allDay = candidate
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Build attributes
|
||||
attrs := fmt.Sprintf(`date_start="%s"`, dateStart)
|
||||
if dateStop != "" {
|
||||
attrs += fmt.Sprintf(` date_stop="%s"`, dateStop)
|
||||
}
|
||||
if colorField != "" {
|
||||
attrs += fmt.Sprintf(` color="%s"`, colorField)
|
||||
}
|
||||
if allDay != "" {
|
||||
attrs += fmt.Sprintf(` all_day="%s"`, allDay)
|
||||
}
|
||||
|
||||
// Pick display fields for the calendar card
|
||||
var fields []string
|
||||
nameField := "display_name"
|
||||
if f := m.GetField("name"); f != nil {
|
||||
nameField = "name"
|
||||
}
|
||||
fields = append(fields, fmt.Sprintf(` <field name="%s"/>`, nameField))
|
||||
|
||||
if f := m.GetField("partner_id"); f != nil {
|
||||
fields = append(fields, ` <field name="partner_id" avatar_field="avatar_128"/>`)
|
||||
}
|
||||
if f := m.GetField("user_id"); f != nil && colorField != "user_id" {
|
||||
fields = append(fields, ` <field name="user_id"/>`)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("<calendar %s mode=\"month\">\n%s\n</calendar>",
|
||||
attrs, strings.Join(fields, "\n"))
|
||||
}
|
||||
|
||||
// generateDefaultActivityView creates a minimal activity view.
|
||||
// Mirrors: odoo/addons/mail/static/src/views/web_activity/activity_arch_parser.js
|
||||
func generateDefaultActivityView(m *orm.Model) string {
|
||||
nameField := "display_name"
|
||||
if f := m.GetField("name"); f != nil {
|
||||
nameField = "name"
|
||||
}
|
||||
return fmt.Sprintf(`<activity string="Activities">
|
||||
<templates>
|
||||
<div t-name="activity-box">
|
||||
<field name="%s"/>
|
||||
</div>
|
||||
</templates>
|
||||
</activity>`, nameField)
|
||||
}
|
||||
|
||||
// generateDefaultDashboardView creates a dashboard view with aggregate widgets.
|
||||
// Mirrors: odoo/addons/board/static/src/board_view.js
|
||||
func generateDefaultDashboardView(m *orm.Model) string {
|
||||
var widgets []string
|
||||
|
||||
// Add aggregate widgets for numeric fields
|
||||
for _, name := range sortedFieldNames(m) {
|
||||
f := m.GetField(name)
|
||||
if f == nil {
|
||||
continue
|
||||
}
|
||||
if (f.Type == orm.TypeFloat || f.Type == orm.TypeInteger || f.Type == orm.TypeMonetary) &&
|
||||
f.IsStored() && f.Name != "id" && f.Name != "sequence" &&
|
||||
f.Name != "create_uid" && f.Name != "write_uid" && f.Name != "company_id" {
|
||||
widgets = append(widgets, fmt.Sprintf(
|
||||
` <aggregate name="%s" field="%s" string="%s"/>`,
|
||||
f.Name, f.Name, f.String))
|
||||
if len(widgets) >= 6 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add a graph for the first groupable dimension
|
||||
var graphField string
|
||||
for _, name := range sortedFieldNames(m) {
|
||||
f := m.GetField(name)
|
||||
if f != nil && f.IsStored() && (f.Type == orm.TypeMany2one || f.Type == orm.TypeSelection) {
|
||||
graphField = name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
buf.WriteString("<dashboard>\n")
|
||||
if len(widgets) > 0 {
|
||||
buf.WriteString(" <group>\n")
|
||||
for _, w := range widgets {
|
||||
buf.WriteString(w + "\n")
|
||||
}
|
||||
buf.WriteString(" </group>\n")
|
||||
}
|
||||
if graphField != "" {
|
||||
buf.WriteString(fmt.Sprintf(` <view type="graph">
|
||||
<graph type="bar">
|
||||
<field name="%s"/>
|
||||
</graph>
|
||||
</view>
|
||||
`, graphField))
|
||||
}
|
||||
buf.WriteString("</dashboard>")
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// sortedFieldNames returns field names in alphabetical order for deterministic output.
|
||||
func sortedFieldNames(m *orm.Model) []string {
|
||||
fields := m.Fields()
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"odoo-go/pkg/orm"
|
||||
@@ -451,6 +452,110 @@ func (s *Server) handleReadGroup(rs *orm.Recordset, params CallKWParams) (interf
|
||||
}
|
||||
|
||||
if params.Method == "web_read_group" {
|
||||
// --- __fold support ---
|
||||
// If the first groupby is a Many2one whose comodel has a "fold" field,
|
||||
// add __fold to each group. Mirrors: odoo/addons/web/models/models.py
|
||||
if len(groupby) > 0 {
|
||||
fieldName := strings.SplitN(groupby[0], ":", 2)[0]
|
||||
m := rs.ModelDef()
|
||||
if m != nil {
|
||||
f := m.GetField(fieldName)
|
||||
if f != nil && f.Type == orm.TypeMany2one && f.Comodel != "" {
|
||||
comodel := orm.Registry.Get(f.Comodel)
|
||||
if comodel != nil && comodel.GetField("fold") != nil {
|
||||
addFoldInfo(rs.Env(), f.Comodel, groupby[0], groups)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- __records for auto_unfold ---
|
||||
autoUnfold := false
|
||||
if v, ok := params.KW["auto_unfold"].(bool); ok {
|
||||
autoUnfold = v
|
||||
}
|
||||
if autoUnfold {
|
||||
unfoldReadSpec, _ := params.KW["unfold_read_specification"].(map[string]interface{})
|
||||
unfoldLimit := defaultWebSearchLimit
|
||||
if v, ok := params.KW["unfold_read_default_limit"].(float64); ok {
|
||||
unfoldLimit = int(v)
|
||||
}
|
||||
|
||||
// Parse original domain for combining with group domain
|
||||
origDomain := parseDomain(params.Args)
|
||||
if origDomain == nil {
|
||||
if dr, ok := params.KW["domain"].([]interface{}); ok && len(dr) > 0 {
|
||||
origDomain = parseDomain([]interface{}{dr})
|
||||
}
|
||||
}
|
||||
|
||||
modelName := rs.ModelDef().Name()
|
||||
maxUnfolded := 10
|
||||
unfolded := 0
|
||||
for _, g := range groups {
|
||||
if unfolded >= maxUnfolded {
|
||||
break
|
||||
}
|
||||
gm := g.(map[string]interface{})
|
||||
fold, _ := gm["__fold"].(bool)
|
||||
count, _ := gm["__count"].(int64)
|
||||
// Skip folded, empty, and groups with false/nil M2O value
|
||||
// Mirrors: odoo/addons/web/models/models.py _open_groups() fold checks
|
||||
if fold || count == 0 {
|
||||
continue
|
||||
}
|
||||
// For M2O groupby: skip groups where the value is false (unset M2O)
|
||||
if len(groupby) > 0 {
|
||||
gbVal := gm[groupby[0]]
|
||||
if gbVal == nil || gbVal == false {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Build combined domain: original + group extra domain
|
||||
var combinedDomain orm.Domain
|
||||
if origDomain != nil {
|
||||
combinedDomain = append(combinedDomain, origDomain...)
|
||||
}
|
||||
if extraDom, ok := gm["__extra_domain"].([]interface{}); ok && len(extraDom) > 0 {
|
||||
groupDomain := parseDomain([]interface{}{extraDom})
|
||||
combinedDomain = append(combinedDomain, groupDomain...)
|
||||
}
|
||||
|
||||
found, err := rs.Env().Model(modelName).Search(combinedDomain, orm.SearchOpts{Limit: unfoldLimit})
|
||||
if err != nil || found.IsEmpty() {
|
||||
gm["__records"] = []orm.Values{}
|
||||
unfolded++
|
||||
continue
|
||||
}
|
||||
|
||||
fields := specToFields(unfoldReadSpec)
|
||||
if len(fields) == 0 {
|
||||
fields = []string{"id"}
|
||||
}
|
||||
hasID := false
|
||||
for _, f := range fields {
|
||||
if f == "id" {
|
||||
hasID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasID {
|
||||
fields = append([]string{"id"}, fields...)
|
||||
}
|
||||
|
||||
records, err := found.Read(fields)
|
||||
if err != nil {
|
||||
gm["__records"] = []orm.Values{}
|
||||
unfolded++
|
||||
continue
|
||||
}
|
||||
formatRecordsForWeb(rs.Env(), modelName, records, unfoldReadSpec)
|
||||
gm["__records"] = records
|
||||
unfolded++
|
||||
}
|
||||
}
|
||||
|
||||
// web_read_group: also get total group count (without limit/offset)
|
||||
totalLen := len(results)
|
||||
if opts.Limit > 0 || opts.Offset > 0 {
|
||||
@@ -470,6 +575,203 @@ func (s *Server) handleReadGroup(rs *orm.Recordset, params CallKWParams) (interf
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// handleReadProgressBar returns per-group counts for a progress bar field.
|
||||
// Mirrors: odoo/orm/models.py BaseModel._read_progress_bar()
|
||||
//
|
||||
// Called by the kanban view to render colored progress bars per column.
|
||||
// Input (via KW):
|
||||
//
|
||||
// domain: search filter
|
||||
// group_by: field to group columns by (e.g. "stage_id")
|
||||
// progress_bar: {field: "kanban_state", colors: {"done": "success", ...}}
|
||||
//
|
||||
// Output:
|
||||
//
|
||||
// {groupByValue: {pbValue: count, ...}, ...}
|
||||
//
|
||||
// Where groupByValue is the raw DB value (integer ID for M2O, string for
|
||||
// selection, "True"/"False" for boolean).
|
||||
func (s *Server) handleReadProgressBar(rs *orm.Recordset, params CallKWParams) (interface{}, *RPCError) {
|
||||
// Parse domain from KW
|
||||
domain := parseDomain(params.Args)
|
||||
if domain == nil {
|
||||
if dr, ok := params.KW["domain"].([]interface{}); ok && len(dr) > 0 {
|
||||
domain = parseDomain([]interface{}{dr})
|
||||
}
|
||||
}
|
||||
|
||||
// Parse group_by (single string)
|
||||
groupBy := ""
|
||||
if v, ok := params.KW["group_by"].(string); ok {
|
||||
groupBy = v
|
||||
}
|
||||
|
||||
// Parse progress_bar map
|
||||
progressBar, _ := params.KW["progress_bar"].(map[string]interface{})
|
||||
pbField, _ := progressBar["field"].(string)
|
||||
|
||||
if groupBy == "" || pbField == "" {
|
||||
return map[string]interface{}{}, nil
|
||||
}
|
||||
|
||||
// Use ReadGroup with two groupby levels: [groupBy, pbField]
|
||||
results, err := rs.ReadGroup(domain, []string{groupBy, pbField}, []string{"__count"})
|
||||
if err != nil {
|
||||
return map[string]interface{}{}, nil
|
||||
}
|
||||
|
||||
// Determine field types for key formatting
|
||||
m := rs.ModelDef()
|
||||
gbField := m.GetField(groupBy)
|
||||
pbFieldDef := m.GetField(pbField)
|
||||
|
||||
// Build nested map: {groupByValue: {pbValue: count}}
|
||||
data := make(map[string]interface{})
|
||||
|
||||
// Collect all known progress bar values (from colors) so we initialize zeros
|
||||
pbColors, _ := progressBar["colors"].(map[string]interface{})
|
||||
|
||||
for _, r := range results {
|
||||
// Format the group-by key
|
||||
gbVal := r.GroupValues[groupBy]
|
||||
gbKey := formatProgressBarKey(gbVal, gbField)
|
||||
|
||||
// Format the progress bar value
|
||||
pbVal := r.GroupValues[pbField]
|
||||
pbKey := formatProgressBarValue(pbVal, pbFieldDef)
|
||||
|
||||
// Initialize group entry with zero counts if first time
|
||||
if _, exists := data[gbKey]; !exists {
|
||||
entry := make(map[string]interface{})
|
||||
for colorKey := range pbColors {
|
||||
entry[colorKey] = 0
|
||||
}
|
||||
data[gbKey] = entry
|
||||
}
|
||||
|
||||
// Add count
|
||||
entry := data[gbKey].(map[string]interface{})
|
||||
existing, _ := entry[pbKey].(int)
|
||||
entry[pbKey] = existing + int(r.Count)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// formatProgressBarKey formats a group-by value as the string key expected
|
||||
// by the frontend progress bar.
|
||||
// - M2O: integer ID (as string)
|
||||
// - Boolean: "True" / "False"
|
||||
// - nil/false: "False"
|
||||
// - Other: value as string
|
||||
func formatProgressBarKey(val interface{}, f *orm.Field) string {
|
||||
if val == nil || val == false {
|
||||
return "False"
|
||||
}
|
||||
|
||||
// M2O: ReadGroup resolves to [id, name] pair — use the id
|
||||
if f != nil && f.Type == orm.TypeMany2one {
|
||||
switch v := val.(type) {
|
||||
case []interface{}:
|
||||
if len(v) > 0 {
|
||||
return fmt.Sprintf("%v", v[0])
|
||||
}
|
||||
return "False"
|
||||
case int64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case float64:
|
||||
return fmt.Sprintf("%d", int64(v))
|
||||
case int:
|
||||
return fmt.Sprintf("%d", v)
|
||||
}
|
||||
}
|
||||
|
||||
// Boolean
|
||||
if f != nil && f.Type == orm.TypeBoolean {
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
return "True"
|
||||
}
|
||||
return "False"
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v", val)
|
||||
}
|
||||
|
||||
// formatProgressBarValue formats a progress bar field value as a string key.
|
||||
// Selection fields use the raw value (e.g. "done", "blocked").
|
||||
// Boolean fields use "True"/"False".
|
||||
func formatProgressBarValue(val interface{}, f *orm.Field) string {
|
||||
if val == nil || val == false {
|
||||
return "False"
|
||||
}
|
||||
if f != nil && f.Type == orm.TypeBoolean {
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
return "True"
|
||||
}
|
||||
return "False"
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%v", val)
|
||||
}
|
||||
|
||||
// addFoldInfo reads the "fold" boolean from the comodel records referenced
|
||||
// by each group and sets __fold on the group maps accordingly.
|
||||
func addFoldInfo(env *orm.Environment, comodel string, groupbySpec string, groups []interface{}) {
|
||||
// Collect IDs from group values (M2O pairs like [id, name])
|
||||
var ids []int64
|
||||
for _, g := range groups {
|
||||
gm := g.(map[string]interface{})
|
||||
val := gm[groupbySpec]
|
||||
if pair, ok := val.([]interface{}); ok && len(pair) >= 1 {
|
||||
if id, ok := orm.ToRecordID(pair[0]); ok && id > 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
// All groups have false/empty value — fold them by default
|
||||
for _, g := range groups {
|
||||
gm := g.(map[string]interface{})
|
||||
gm["__fold"] = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Read fold values from comodel
|
||||
rs := env.Model(comodel).Browse(ids...)
|
||||
records, err := rs.Read([]string{"id", "fold"})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Build fold map
|
||||
foldMap := make(map[int64]bool)
|
||||
for _, rec := range records {
|
||||
id, _ := orm.ToRecordID(rec["id"])
|
||||
fold, _ := rec["fold"].(bool)
|
||||
foldMap[id] = fold
|
||||
}
|
||||
|
||||
// Apply to groups
|
||||
for _, g := range groups {
|
||||
gm := g.(map[string]interface{})
|
||||
val := gm[groupbySpec]
|
||||
if pair, ok := val.([]interface{}); ok && len(pair) >= 1 {
|
||||
if id, ok := orm.ToRecordID(pair[0]); ok {
|
||||
gm["__fold"] = foldMap[id]
|
||||
}
|
||||
} else {
|
||||
// false/empty group value
|
||||
gm["__fold"] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// formatDateFields converts date/datetime values to Odoo's expected string format.
|
||||
func formatDateFields(model string, records []orm.Values) {
|
||||
m := orm.Registry.Get(model)
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -73,6 +74,14 @@ func (s *Server) handleWebClient(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if post-setup wizard is needed (first login, company not configured)
|
||||
if s.isPostSetupNeeded() {
|
||||
if sess := GetSession(r); sess != nil {
|
||||
http.Redirect(w, r, "/web/setup/wizard", http.StatusFound)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check authentication
|
||||
sess := GetSession(r)
|
||||
if sess == nil {
|
||||
@@ -141,7 +150,7 @@ func (s *Server) handleWebClient(w http.ResponseWriter, r *http.Request) {
|
||||
%s
|
||||
<script>
|
||||
var odoo = {
|
||||
csrf_token: "dummy",
|
||||
csrf_token: "%s",
|
||||
debug: "assets",
|
||||
__session_info__: %s,
|
||||
reloadMenus: function() {
|
||||
@@ -178,12 +187,18 @@ func (s *Server) handleWebClient(w http.ResponseWriter, r *http.Request) {
|
||||
%s</head>
|
||||
<body class="o_web_client">
|
||||
</body>
|
||||
</html>`, linkTags.String(), sessionInfoJSON, scriptTags.String())
|
||||
</html>`, linkTags.String(), sess.CSRFToken, sessionInfoJSON, scriptTags.String())
|
||||
}
|
||||
|
||||
// buildSessionInfo constructs the session_info JSON object expected by the webclient.
|
||||
// Mirrors: odoo/addons/web/models/ir_http.py session_info()
|
||||
func (s *Server) buildSessionInfo(sess *Session) map[string]interface{} {
|
||||
// Build allowed_company_ids from session (populated at login)
|
||||
allowedIDs := sess.AllowedCompanyIDs
|
||||
if len(allowedIDs) == 0 {
|
||||
allowedIDs = []int64{sess.CompanyID}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"session_id": sess.ID,
|
||||
"uid": sess.UID,
|
||||
@@ -194,7 +209,7 @@ func (s *Server) buildSessionInfo(sess *Session) map[string]interface{} {
|
||||
"user_context": map[string]interface{}{
|
||||
"lang": "en_US",
|
||||
"tz": "UTC",
|
||||
"allowed_company_ids": []int64{sess.CompanyID},
|
||||
"allowed_company_ids": allowedIDs,
|
||||
},
|
||||
"db": s.config.DBName,
|
||||
"registry_hash": fmt.Sprintf("odoo-go-%d", time.Now().Unix()),
|
||||
@@ -213,7 +228,7 @@ func (s *Server) buildSessionInfo(sess *Session) map[string]interface{} {
|
||||
"current_menu": 1,
|
||||
"support_url": "",
|
||||
"notification_type": "email",
|
||||
"display_switch_company_menu": false,
|
||||
"display_switch_company_menu": len(allowedIDs) > 1,
|
||||
"test_mode": false,
|
||||
"show_effect": true,
|
||||
"currencies": map[string]interface{}{
|
||||
@@ -226,20 +241,7 @@ func (s *Server) buildSessionInfo(sess *Session) map[string]interface{} {
|
||||
"lang": "en_US",
|
||||
"debug": "assets",
|
||||
},
|
||||
"user_companies": map[string]interface{}{
|
||||
"current_company": sess.CompanyID,
|
||||
"allowed_companies": map[string]interface{}{
|
||||
fmt.Sprintf("%d", sess.CompanyID): map[string]interface{}{
|
||||
"id": sess.CompanyID,
|
||||
"name": "My Company",
|
||||
"sequence": 10,
|
||||
"child_ids": []int64{},
|
||||
"parent_id": false,
|
||||
"currency_id": 1,
|
||||
},
|
||||
},
|
||||
"disallowed_ancestor_companies": map[string]interface{}{},
|
||||
},
|
||||
"user_companies": s.buildUserCompanies(sess.CompanyID, allowedIDs),
|
||||
"user_settings": map[string]interface{}{
|
||||
"id": 1,
|
||||
"user_id": map[string]interface{}{"id": sess.UID, "display_name": sess.Login},
|
||||
@@ -365,3 +367,105 @@ func (s *Server) handleTranslations(w http.ResponseWriter, r *http.Request) {
|
||||
"multi_lang": multiLang,
|
||||
})
|
||||
}
|
||||
|
||||
// buildUserCompanies queries company data and builds the user_companies dict
|
||||
// for the session_info response. Mirrors: odoo/addons/web/models/ir_http.py
|
||||
func (s *Server) buildUserCompanies(currentCompanyID int64, allowedIDs []int64) map[string]interface{} {
|
||||
allowedCompanies := make(map[string]interface{})
|
||||
|
||||
// Batch query all companies at once
|
||||
rows, err := s.pool.Query(context.Background(),
|
||||
`SELECT id, COALESCE(name, 'Company'), COALESCE(currency_id, 1)
|
||||
FROM res_company WHERE id = ANY($1)`, allowedIDs)
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var cid, currencyID int64
|
||||
var name string
|
||||
if rows.Scan(&cid, &name, ¤cyID) == nil {
|
||||
allowedCompanies[fmt.Sprintf("%d", cid)] = map[string]interface{}{
|
||||
"id": cid,
|
||||
"name": name,
|
||||
"sequence": 10,
|
||||
"child_ids": []int64{},
|
||||
"parent_id": false,
|
||||
"currency_id": currencyID,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback for any IDs not found in DB
|
||||
for _, cid := range allowedIDs {
|
||||
key := fmt.Sprintf("%d", cid)
|
||||
if _, exists := allowedCompanies[key]; !exists {
|
||||
allowedCompanies[key] = map[string]interface{}{
|
||||
"id": cid, "name": fmt.Sprintf("Company %d", cid),
|
||||
"sequence": 10, "child_ids": []int64{}, "parent_id": false, "currency_id": int64(1),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"current_company": currentCompanyID,
|
||||
"allowed_companies": allowedCompanies,
|
||||
"disallowed_ancestor_companies": map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
|
||||
// handleSwitchCompany switches the active company for the current session.
|
||||
func (s *Server) handleSwitchCompany(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
sess := GetSession(r)
|
||||
if sess == nil {
|
||||
s.writeJSONRPC(w, nil, nil, &RPCError{Code: 100, Message: "Not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var req JSONRPCRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.writeJSONRPC(w, nil, nil, &RPCError{Code: -32700, Message: "Parse error"})
|
||||
return
|
||||
}
|
||||
|
||||
var params struct {
|
||||
CompanyID int64 `json:"company_id"`
|
||||
}
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil || params.CompanyID == 0 {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "Invalid company_id"})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate company is in allowed list
|
||||
allowed := false
|
||||
for _, cid := range sess.AllowedCompanyIDs {
|
||||
if cid == params.CompanyID {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: 403, Message: "Company not in allowed list"})
|
||||
return
|
||||
}
|
||||
|
||||
// Update session
|
||||
sess.CompanyID = params.CompanyID
|
||||
|
||||
// Persist to DB
|
||||
if s.sessions.pool != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.sessions.pool.Exec(ctx,
|
||||
`UPDATE sessions SET company_id = $1 WHERE id = $2`, params.CompanyID, sess.ID)
|
||||
}
|
||||
|
||||
s.writeJSONRPC(w, req.ID, map[string]interface{}{
|
||||
"company_id": params.CompanyID,
|
||||
"result": true,
|
||||
}, nil)
|
||||
}
|
||||
|
||||
160
pkg/service/automation.go
Normal file
160
pkg/service/automation.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"odoo-go/pkg/orm"
|
||||
"odoo-go/pkg/tools"
|
||||
)
|
||||
|
||||
// RunAutomatedActions checks and executes server actions triggered by Create/Write/Unlink.
|
||||
// Called from the ORM after successful Create/Write/Unlink operations.
|
||||
// Mirrors: odoo/addons/base_automation/models/base_automation.py
|
||||
func RunAutomatedActions(env *orm.Environment, modelName, trigger string, recordIDs []int64) {
|
||||
if len(recordIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Look up the ir_model ID for this model
|
||||
var modelID int64
|
||||
err := env.Tx().QueryRow(env.Ctx(),
|
||||
`SELECT id FROM ir_model WHERE model = $1`, modelName).Scan(&modelID)
|
||||
if err != nil {
|
||||
return // Model not in ir_model — no actions possible
|
||||
}
|
||||
|
||||
// Find matching automated actions
|
||||
rows, err := env.Tx().Query(env.Ctx(),
|
||||
`SELECT id, state, COALESCE(update_field_id, ''), COALESCE(update_value, ''),
|
||||
COALESCE(email_to, ''), COALESCE(email_subject, ''), COALESCE(email_body, ''),
|
||||
COALESCE(filter_domain, '')
|
||||
FROM ir_act_server
|
||||
WHERE model_id = $1
|
||||
AND active = true
|
||||
AND trigger = $2
|
||||
ORDER BY sequence, id`, modelID, trigger)
|
||||
if err != nil {
|
||||
log.Printf("automation: query error for %s/%s: %v", modelName, trigger, err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type action struct {
|
||||
id int64
|
||||
state string
|
||||
updateField string
|
||||
updateValue string
|
||||
emailTo string
|
||||
emailSubject string
|
||||
emailBody string
|
||||
filterDomain string
|
||||
}
|
||||
|
||||
var actions []action
|
||||
for rows.Next() {
|
||||
var a action
|
||||
if err := rows.Scan(&a.id, &a.state, &a.updateField, &a.updateValue,
|
||||
&a.emailTo, &a.emailSubject, &a.emailBody, &a.filterDomain); err != nil {
|
||||
continue
|
||||
}
|
||||
actions = append(actions, a)
|
||||
}
|
||||
|
||||
if len(actions) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, a := range actions {
|
||||
switch a.state {
|
||||
case "object_write":
|
||||
executeObjectWrite(env, modelName, recordIDs, a.updateField, a.updateValue)
|
||||
case "email":
|
||||
executeEmailAction(env, modelName, recordIDs, a.emailTo, a.emailSubject, a.emailBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// executeObjectWrite updates a field on the triggered records.
|
||||
func executeObjectWrite(env *orm.Environment, modelName string, recordIDs []int64, fieldName, value string) {
|
||||
if fieldName == "" {
|
||||
return
|
||||
}
|
||||
tableName := strings.ReplaceAll(modelName, ".", "_")
|
||||
for _, id := range recordIDs {
|
||||
_, err := env.Tx().Exec(env.Ctx(),
|
||||
fmt.Sprintf(`UPDATE %s SET %s = $1 WHERE id = $2`,
|
||||
pgx.Identifier{tableName}.Sanitize(),
|
||||
pgx.Identifier{fieldName}.Sanitize()),
|
||||
value, id)
|
||||
if err != nil {
|
||||
log.Printf("automation: object_write error %s.%s on %d: %v", modelName, fieldName, id, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// executeEmailAction sends an email for each triggered record.
|
||||
func executeEmailAction(env *orm.Environment, modelName string, recordIDs []int64, emailToField, subject, bodyTemplate string) {
|
||||
if emailToField == "" {
|
||||
return
|
||||
}
|
||||
|
||||
cfg := tools.LoadSMTPConfig()
|
||||
if cfg.Host == "" {
|
||||
return
|
||||
}
|
||||
|
||||
tableName := strings.ReplaceAll(modelName, ".", "_")
|
||||
|
||||
for _, id := range recordIDs {
|
||||
// Resolve email address from the record
|
||||
var email string
|
||||
err := env.Tx().QueryRow(env.Ctx(),
|
||||
fmt.Sprintf(`SELECT COALESCE(%s, '') FROM %s WHERE id = $1`,
|
||||
pgx.Identifier{emailToField}.Sanitize(),
|
||||
pgx.Identifier{tableName}.Sanitize()),
|
||||
id).Scan(&email)
|
||||
if err != nil || email == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Simple template: replace {{field}} with record values
|
||||
body := bodyTemplate
|
||||
if strings.Contains(body, "{{") {
|
||||
body = resolveTemplate(env, tableName, id, body)
|
||||
}
|
||||
|
||||
if err := tools.SendEmail(cfg, email, subject, body); err != nil {
|
||||
log.Printf("automation: email error to %s for %s/%d: %v", email, modelName, id, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveTemplate replaces {{field_name}} placeholders with actual record values.
|
||||
func resolveTemplate(env *orm.Environment, tableName string, recordID int64, template string) string {
|
||||
result := template
|
||||
for {
|
||||
start := strings.Index(result, "{{")
|
||||
if start == -1 {
|
||||
break
|
||||
}
|
||||
end := strings.Index(result[start:], "}}")
|
||||
if end == -1 {
|
||||
break
|
||||
}
|
||||
fieldName := strings.TrimSpace(result[start+2 : start+end])
|
||||
var val string
|
||||
err := env.Tx().QueryRow(env.Ctx(),
|
||||
fmt.Sprintf(`SELECT COALESCE(CAST(%s AS TEXT), '') FROM %s WHERE id = $1`,
|
||||
pgx.Identifier{fieldName}.Sanitize(),
|
||||
pgx.Identifier{tableName}.Sanitize()),
|
||||
recordID).Scan(&val)
|
||||
if err != nil {
|
||||
val = ""
|
||||
}
|
||||
result = result[:start] + val + result[start+end+2:]
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -2,57 +2,70 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"odoo-go/pkg/orm"
|
||||
)
|
||||
|
||||
// CronJob defines a scheduled task.
|
||||
type CronJob struct {
|
||||
Name string
|
||||
Interval time.Duration
|
||||
Handler func(ctx context.Context, pool *pgxpool.Pool) error
|
||||
running bool
|
||||
const (
|
||||
cronPollInterval = 60 * time.Second
|
||||
maxFailureCount = 5
|
||||
)
|
||||
|
||||
// cronJob holds a single scheduled action loaded from the ir_cron table.
|
||||
type cronJob struct {
|
||||
ID int64
|
||||
Name string
|
||||
ModelName string
|
||||
MethodName string
|
||||
UserID int64
|
||||
IntervalNumber int
|
||||
IntervalType string
|
||||
NumberCall int
|
||||
NextCall time.Time
|
||||
}
|
||||
|
||||
// CronScheduler manages periodic jobs.
|
||||
// CronScheduler polls ir_cron and executes ready jobs.
|
||||
// Mirrors: odoo/addons/base/models/ir_cron.py IrCron._process_jobs()
|
||||
type CronScheduler struct {
|
||||
jobs []*CronJob
|
||||
mu sync.Mutex
|
||||
pool *pgxpool.Pool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewCronScheduler creates a new scheduler.
|
||||
func NewCronScheduler() *CronScheduler {
|
||||
// NewCronScheduler creates a DB-driven cron scheduler.
|
||||
func NewCronScheduler(pool *pgxpool.Pool) *CronScheduler {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &CronScheduler{ctx: ctx, cancel: cancel}
|
||||
return &CronScheduler{pool: pool, ctx: ctx, cancel: cancel}
|
||||
}
|
||||
|
||||
// Register adds a job to the scheduler.
|
||||
func (s *CronScheduler) Register(job *CronJob) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.jobs = append(s.jobs, job)
|
||||
// Start begins the polling loop in a background goroutine.
|
||||
func (s *CronScheduler) Start() {
|
||||
s.wg.Add(1)
|
||||
go s.pollLoop()
|
||||
log.Println("cron: scheduler started")
|
||||
}
|
||||
|
||||
// Start begins running all registered jobs.
|
||||
func (s *CronScheduler) Start(pool *pgxpool.Pool) {
|
||||
for _, job := range s.jobs {
|
||||
go s.runJob(job, pool)
|
||||
}
|
||||
log.Printf("cron: started %d jobs", len(s.jobs))
|
||||
}
|
||||
|
||||
// Stop cancels all running jobs.
|
||||
// Stop cancels the polling loop and waits for completion.
|
||||
func (s *CronScheduler) Stop() {
|
||||
s.cancel()
|
||||
s.wg.Wait()
|
||||
log.Println("cron: scheduler stopped")
|
||||
}
|
||||
|
||||
func (s *CronScheduler) runJob(job *CronJob, pool *pgxpool.Pool) {
|
||||
ticker := time.NewTicker(job.Interval)
|
||||
func (s *CronScheduler) pollLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
// Run once immediately, then on ticker
|
||||
s.processJobs()
|
||||
|
||||
ticker := time.NewTicker(cronPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
@@ -60,9 +73,200 @@ func (s *CronScheduler) runJob(job *CronJob, pool *pgxpool.Pool) {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := job.Handler(s.ctx, pool); err != nil {
|
||||
log.Printf("cron: %s error: %v", job.Name, err)
|
||||
}
|
||||
s.processJobs()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processJobs queries all ready cron jobs and processes them one by one.
|
||||
func (s *CronScheduler) processJobs() {
|
||||
rows, err := s.pool.Query(s.ctx, `
|
||||
SELECT id, name, model_name, method_name, user_id,
|
||||
interval_number, interval_type, numbercall, nextcall
|
||||
FROM ir_cron
|
||||
WHERE active = true AND nextcall <= now()
|
||||
ORDER BY priority, id
|
||||
`)
|
||||
if err != nil {
|
||||
log.Printf("cron: query error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var jobs []cronJob
|
||||
for rows.Next() {
|
||||
var j cronJob
|
||||
var modelName, methodName *string // nullable
|
||||
if err := rows.Scan(&j.ID, &j.Name, &modelName, &methodName, &j.UserID,
|
||||
&j.IntervalNumber, &j.IntervalType, &j.NumberCall, &j.NextCall); err != nil {
|
||||
log.Printf("cron: scan error: %v", err)
|
||||
continue
|
||||
}
|
||||
if modelName != nil {
|
||||
j.ModelName = *modelName
|
||||
}
|
||||
if methodName != nil {
|
||||
j.MethodName = *methodName
|
||||
}
|
||||
jobs = append(jobs, j)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
for _, job := range jobs {
|
||||
s.processOneJob(job)
|
||||
}
|
||||
}
|
||||
|
||||
// processOneJob acquires a row-level lock and executes a single cron job.
|
||||
func (s *CronScheduler) processOneJob(job cronJob) {
|
||||
tx, err := s.pool.Begin(s.ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(s.ctx)
|
||||
|
||||
// Try to acquire the job with FOR NO KEY UPDATE SKIP LOCKED
|
||||
var lockedID int64
|
||||
err = tx.QueryRow(s.ctx, `
|
||||
SELECT id FROM ir_cron
|
||||
WHERE id = $1 AND active = true AND nextcall <= now()
|
||||
FOR NO KEY UPDATE SKIP LOCKED
|
||||
`, job.ID).Scan(&lockedID)
|
||||
if err != nil {
|
||||
// Job already taken by another worker or not ready
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("cron: executing %q (id=%d)", job.Name, job.ID)
|
||||
|
||||
execErr := s.executeJob(job)
|
||||
|
||||
now := time.Now()
|
||||
nextCall := calculateNextCall(now, job.IntervalNumber, job.IntervalType)
|
||||
|
||||
if execErr != nil {
|
||||
log.Printf("cron: %q failed: %v", job.Name, execErr)
|
||||
|
||||
// Update failure count, set first_failure_date if not already set
|
||||
if _, err := tx.Exec(s.ctx, `
|
||||
UPDATE ir_cron SET
|
||||
failure_count = failure_count + 1,
|
||||
first_failure_date = COALESCE(first_failure_date, $1),
|
||||
lastcall = $1,
|
||||
nextcall = $2
|
||||
WHERE id = $3
|
||||
`, now, nextCall, job.ID); err != nil {
|
||||
log.Printf("cron: failed to update failure count for %q: %v", job.Name, err)
|
||||
}
|
||||
|
||||
// Deactivate if too many consecutive failures
|
||||
if _, err := tx.Exec(s.ctx, `
|
||||
UPDATE ir_cron SET active = false
|
||||
WHERE id = $1 AND failure_count >= $2
|
||||
`, job.ID, maxFailureCount); err != nil {
|
||||
log.Printf("cron: failed to deactivate %q: %v", job.Name, err)
|
||||
}
|
||||
} else {
|
||||
log.Printf("cron: %q completed successfully", job.Name)
|
||||
|
||||
if job.NumberCall > 0 {
|
||||
// Finite run count: decrement
|
||||
newNumberCall := job.NumberCall - 1
|
||||
if newNumberCall <= 0 {
|
||||
if _, err := tx.Exec(s.ctx, `
|
||||
UPDATE ir_cron SET active = false, lastcall = $1, nextcall = $2,
|
||||
failure_count = 0, first_failure_date = NULL, numbercall = 0
|
||||
WHERE id = $3
|
||||
`, now, nextCall, job.ID); err != nil {
|
||||
log.Printf("cron: failed to update job %q: %v", job.Name, err)
|
||||
}
|
||||
} else {
|
||||
if _, err := tx.Exec(s.ctx, `
|
||||
UPDATE ir_cron SET lastcall = $1, nextcall = $2,
|
||||
failure_count = 0, first_failure_date = NULL, numbercall = $3
|
||||
WHERE id = $4
|
||||
`, now, nextCall, newNumberCall, job.ID); err != nil {
|
||||
log.Printf("cron: failed to update job %q: %v", job.Name, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// numbercall <= 0 means infinite runs
|
||||
if _, err := tx.Exec(s.ctx, `
|
||||
UPDATE ir_cron SET lastcall = $1, nextcall = $2,
|
||||
failure_count = 0, first_failure_date = NULL
|
||||
WHERE id = $3
|
||||
`, now, nextCall, job.ID); err != nil {
|
||||
log.Printf("cron: failed to update job %q: %v", job.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(s.ctx); err != nil {
|
||||
log.Printf("cron: commit error for %q: %v", job.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// executeJob looks up the target method in orm.Registry and calls it.
|
||||
func (s *CronScheduler) executeJob(job cronJob) error {
|
||||
if job.ModelName == "" || job.MethodName == "" {
|
||||
return fmt.Errorf("cron %q: model_name or method_name not set", job.Name)
|
||||
}
|
||||
|
||||
model := orm.Registry.Get(job.ModelName)
|
||||
if model == nil {
|
||||
return fmt.Errorf("cron %q: model %q not found", job.Name, job.ModelName)
|
||||
}
|
||||
if model.Methods == nil {
|
||||
return fmt.Errorf("cron %q: model %q has no methods", job.Name, job.ModelName)
|
||||
}
|
||||
method, ok := model.Methods[job.MethodName]
|
||||
if !ok {
|
||||
return fmt.Errorf("cron %q: method %q not found on %q", job.Name, job.MethodName, job.ModelName)
|
||||
}
|
||||
|
||||
// Create ORM environment for job execution
|
||||
uid := job.UserID
|
||||
if uid == 0 {
|
||||
return fmt.Errorf("cron %q: user_id not set, refusing to run as admin", job.Name)
|
||||
}
|
||||
|
||||
env, err := orm.NewEnvironment(s.ctx, orm.EnvConfig{
|
||||
Pool: s.pool,
|
||||
UID: uid,
|
||||
Context: map[string]interface{}{
|
||||
"lastcall": job.NextCall,
|
||||
"cron_id": job.ID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("cron %q: env error: %w", job.Name, err)
|
||||
}
|
||||
defer env.Close()
|
||||
|
||||
// Call the method on an empty recordset of the target model
|
||||
_, err = method(env.Model(job.ModelName))
|
||||
if err != nil {
|
||||
env.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return env.Commit()
|
||||
}
|
||||
|
||||
// calculateNextCall computes the next execution time based on interval.
|
||||
// Mirrors: odoo/addons/base/models/ir_cron.py _intervalTypes
|
||||
func calculateNextCall(from time.Time, number int, intervalType string) time.Time {
|
||||
switch intervalType {
|
||||
case "minutes":
|
||||
return from.Add(time.Duration(number) * time.Minute)
|
||||
case "hours":
|
||||
return from.Add(time.Duration(number) * time.Hour)
|
||||
case "days":
|
||||
return from.AddDate(0, 0, number)
|
||||
case "weeks":
|
||||
return from.AddDate(0, 0, number*7)
|
||||
case "months":
|
||||
return from.AddDate(0, number, 0)
|
||||
default:
|
||||
return from.Add(time.Duration(number) * time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -330,6 +330,8 @@ func SeedWithSetup(ctx context.Context, pool *pgxpool.Pool, cfg SetupConfig) err
|
||||
VALUES (1, 1, true, 1, 1) ON CONFLICT (id) DO NOTHING`)
|
||||
})
|
||||
|
||||
safeExec(ctx, tx, "base_groups", func() { seedBaseGroups(ctx, tx) })
|
||||
safeExec(ctx, tx, "acl_rules", func() { seedACLRules(ctx, tx) })
|
||||
safeExec(ctx, tx, "system_params", func() { seedSystemParams(ctx, tx) })
|
||||
safeExec(ctx, tx, "languages", func() { seedLanguages(ctx, tx) })
|
||||
safeExec(ctx, tx, "translations", func() { seedTranslations(ctx, tx) })
|
||||
@@ -1676,3 +1678,136 @@ func generateUUID() string {
|
||||
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
|
||||
}
|
||||
|
||||
// seedBaseGroups creates the base security groups and their XML IDs.
|
||||
// Mirrors: odoo/addons/base/security/base_groups.xml
|
||||
func seedBaseGroups(ctx context.Context, tx pgx.Tx) {
|
||||
log.Println("db: seeding base security groups...")
|
||||
|
||||
type groupDef struct {
|
||||
id int64
|
||||
name string
|
||||
xmlID string
|
||||
}
|
||||
groups := []groupDef{
|
||||
{1, "Internal User", "group_user"},
|
||||
{2, "Settings", "group_system"},
|
||||
{3, "Access Rights", "group_erp_manager"},
|
||||
{4, "Allow Export", "group_allow_export"},
|
||||
{5, "Portal", "group_portal"},
|
||||
{6, "Public", "group_public"},
|
||||
}
|
||||
|
||||
for _, g := range groups {
|
||||
tx.Exec(ctx, `INSERT INTO res_groups (id, name)
|
||||
VALUES ($1, $2) ON CONFLICT (id) DO NOTHING`, g.id, g.name)
|
||||
tx.Exec(ctx, `INSERT INTO ir_model_data (module, name, model, res_id)
|
||||
VALUES ('base', $1, 'res.groups', $2) ON CONFLICT DO NOTHING`, g.xmlID, g.id)
|
||||
}
|
||||
|
||||
// Add admin user (uid=1) to all groups
|
||||
for _, g := range groups {
|
||||
tx.Exec(ctx, `INSERT INTO res_groups_res_users_rel (res_groups_id, res_users_id)
|
||||
VALUES ($1, 1) ON CONFLICT DO NOTHING`, g.id)
|
||||
}
|
||||
}
|
||||
|
||||
// seedACLRules creates access control entries for ALL registered models.
|
||||
// Categorizes models into security tiers and assigns appropriate permissions.
|
||||
// Mirrors: odoo/addons/base/security/ir.model.access.csv + per-module CSVs
|
||||
func seedACLRules(ctx context.Context, tx pgx.Tx) {
|
||||
log.Println("db: seeding ACL rules for all models...")
|
||||
|
||||
// Resolve group IDs
|
||||
var groupSystem, groupUser int64
|
||||
err := tx.QueryRow(ctx,
|
||||
`SELECT g.id FROM res_groups g
|
||||
JOIN ir_model_data imd ON imd.res_id = g.id AND imd.model = 'res.groups'
|
||||
WHERE imd.module = 'base' AND imd.name = 'group_system'`).Scan(&groupSystem)
|
||||
if err != nil {
|
||||
log.Printf("db: cannot find group_system, skipping ACL seeding: %v", err)
|
||||
return
|
||||
}
|
||||
err = tx.QueryRow(ctx,
|
||||
`SELECT g.id FROM res_groups g
|
||||
JOIN ir_model_data imd ON imd.res_id = g.id AND imd.model = 'res.groups'
|
||||
WHERE imd.module = 'base' AND imd.name = 'group_user'`).Scan(&groupUser)
|
||||
if err != nil {
|
||||
log.Printf("db: cannot find group_user, skipping ACL seeding: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// ── Security Tiers ──────────────────────────────────────────────
|
||||
// Tier 1: System-only — only group_system gets full access
|
||||
systemOnly := map[string]bool{
|
||||
"ir.cron": true, "ir.rule": true, "ir.model.access": true,
|
||||
}
|
||||
|
||||
// Tier 2: Admin-only — group_user=read, group_system=full
|
||||
adminOnly := map[string]bool{
|
||||
"ir.model": true, "ir.model.fields": true, "ir.model.data": true,
|
||||
"ir.module.category": true, "ir.actions.server": true, "ir.sequence": true,
|
||||
"ir.logging": true, "ir.config_parameter": true, "ir.default": true,
|
||||
"ir.translation": true, "ir.actions.report": true, "report.paperformat": true,
|
||||
"res.config.settings": true,
|
||||
}
|
||||
|
||||
// Tier 3: Read-only for users — group_user=read, group_system=full
|
||||
readOnly := map[string]bool{
|
||||
"res.currency": true, "res.currency.rate": true,
|
||||
"res.country": true, "res.country.state": true, "res.country.group": true,
|
||||
"res.lang": true, "uom.category": true, "uom.uom": true,
|
||||
"product.category": true, "product.removal": true,
|
||||
"account.account.tag": true, "account.group": true,
|
||||
"account.tax.group": true, "account.tax.repartition.line": true,
|
||||
}
|
||||
|
||||
// Everything else → Tier 4: Standard user (group_user=full, group_system=full)
|
||||
|
||||
// Helper to insert an ACL rule
|
||||
insertACL := func(modelID int64, modelName string, groupID int64, suffix string, read, write, create, unlink bool) {
|
||||
aclName := "access_" + strings.ReplaceAll(modelName, ".", "_") + "_" + suffix
|
||||
tx.Exec(ctx, `
|
||||
INSERT INTO ir_model_access (name, model_id, group_id, perm_read, perm_write, perm_create, perm_unlink, active)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, true)
|
||||
ON CONFLICT DO NOTHING`,
|
||||
aclName, modelID, groupID, read, write, create, unlink)
|
||||
}
|
||||
|
||||
// Iterate ALL registered models
|
||||
allModels := orm.Registry.Models()
|
||||
seeded := 0
|
||||
for _, m := range allModels {
|
||||
modelName := m.Name()
|
||||
if m.IsAbstract() {
|
||||
continue // Abstract models have no table → no ACL needed
|
||||
}
|
||||
|
||||
// Look up ir_model ID
|
||||
var modelID int64
|
||||
err := tx.QueryRow(ctx,
|
||||
"SELECT id FROM ir_model WHERE model = $1", modelName).Scan(&modelID)
|
||||
if err != nil {
|
||||
continue // Not yet in ir_model — will be seeded on next restart
|
||||
}
|
||||
|
||||
if systemOnly[modelName] {
|
||||
// Tier 1: only group_system full access
|
||||
insertACL(modelID, modelName, groupSystem, "system", true, true, true, true)
|
||||
} else if adminOnly[modelName] {
|
||||
// Tier 2: group_user=read, group_system=full
|
||||
insertACL(modelID, modelName, groupUser, "user_read", true, false, false, false)
|
||||
insertACL(modelID, modelName, groupSystem, "system", true, true, true, true)
|
||||
} else if readOnly[modelName] {
|
||||
// Tier 3: group_user=read, group_system=full
|
||||
insertACL(modelID, modelName, groupUser, "user_read", true, false, false, false)
|
||||
insertACL(modelID, modelName, groupSystem, "system", true, true, true, true)
|
||||
} else {
|
||||
// Tier 4: group_user=full, group_system=full
|
||||
insertACL(modelID, modelName, groupUser, "user", true, true, true, true)
|
||||
insertACL(modelID, modelName, groupSystem, "system", true, true, true, true)
|
||||
}
|
||||
seeded++
|
||||
}
|
||||
log.Printf("db: seeded ACL rules for %d models", seeded)
|
||||
}
|
||||
|
||||
|
||||
255
pkg/service/fetchmail.go
Normal file
255
pkg/service/fetchmail.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/emersion/go-imap/v2"
|
||||
"github.com/emersion/go-imap/v2/imapclient"
|
||||
gomessage "github.com/emersion/go-message"
|
||||
_ "github.com/emersion/go-message/charset"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// FetchmailConfig holds IMAP server configuration.
|
||||
type FetchmailConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
UseTLS bool
|
||||
Folder string
|
||||
}
|
||||
|
||||
// LoadFetchmailConfig loads IMAP settings from environment variables.
|
||||
func LoadFetchmailConfig() *FetchmailConfig {
|
||||
cfg := &FetchmailConfig{
|
||||
Port: 993,
|
||||
UseTLS: true,
|
||||
Folder: "INBOX",
|
||||
}
|
||||
cfg.Host = os.Getenv("IMAP_HOST")
|
||||
cfg.User = os.Getenv("IMAP_USER")
|
||||
cfg.Password = os.Getenv("IMAP_PASSWORD")
|
||||
if v := os.Getenv("IMAP_FOLDER"); v != "" {
|
||||
cfg.Folder = v
|
||||
}
|
||||
if os.Getenv("IMAP_TLS") == "false" {
|
||||
cfg.UseTLS = false
|
||||
if cfg.Port == 993 {
|
||||
cfg.Port = 143
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// FetchAndProcessEmails connects to IMAP, fetches unseen emails, and creates
|
||||
// mail.message records in the database. Matches emails to existing threads
|
||||
// via In-Reply-To/References headers.
|
||||
// Mirrors: odoo/addons/fetchmail/models/fetchmail.py fetch_mail()
|
||||
func FetchAndProcessEmails(ctx context.Context, pool *pgxpool.Pool) error {
|
||||
cfg := LoadFetchmailConfig()
|
||||
if cfg.Host == "" {
|
||||
return nil // IMAP not configured
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
|
||||
|
||||
var c *imapclient.Client
|
||||
var err error
|
||||
if cfg.UseTLS {
|
||||
c, err = imapclient.DialTLS(addr, nil)
|
||||
} else {
|
||||
c, err = imapclient.DialInsecure(addr, nil)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchmail: connect to %s: %w", addr, err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
if err := c.Login(cfg.User, cfg.Password).Wait(); err != nil {
|
||||
return fmt.Errorf("fetchmail: login as %s: %w", cfg.User, err)
|
||||
}
|
||||
|
||||
if _, err := c.Select(cfg.Folder, nil).Wait(); err != nil {
|
||||
return fmt.Errorf("fetchmail: select %s: %w", cfg.Folder, err)
|
||||
}
|
||||
|
||||
// Search unseen
|
||||
criteria := &imap.SearchCriteria{
|
||||
NotFlag: []imap.Flag{imap.FlagSeen},
|
||||
}
|
||||
searchData, err := c.Search(criteria, nil).Wait()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchmail: search: %w", err)
|
||||
}
|
||||
|
||||
seqSet := searchData.All
|
||||
if seqSet == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fetch envelope + body
|
||||
fetchOpts := &imap.FetchOptions{
|
||||
Envelope: true,
|
||||
BodySection: []*imap.FetchItemBodySection{{}},
|
||||
}
|
||||
msgs, err := c.Fetch(seqSet, fetchOpts).Collect()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchmail: fetch: %w", err)
|
||||
}
|
||||
|
||||
var processed int
|
||||
for _, msg := range msgs {
|
||||
if err := processOneEmail(ctx, pool, msg); err != nil {
|
||||
log.Printf("fetchmail: process error: %v", err)
|
||||
continue
|
||||
}
|
||||
processed++
|
||||
}
|
||||
|
||||
// Mark as seen
|
||||
if processed > 0 {
|
||||
storeFlags := &imap.StoreFlags{
|
||||
Op: imap.StoreFlagsAdd,
|
||||
Flags: []imap.Flag{imap.FlagSeen},
|
||||
}
|
||||
if _, err := c.Store(seqSet, storeFlags, nil).Collect(); err != nil {
|
||||
log.Printf("fetchmail: mark seen error: %v", err)
|
||||
}
|
||||
log.Printf("fetchmail: processed %d new emails", processed)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processOneEmail(ctx context.Context, pool *pgxpool.Pool, buf *imapclient.FetchMessageBuffer) error {
|
||||
env := buf.Envelope
|
||||
if env == nil {
|
||||
return fmt.Errorf("no envelope")
|
||||
}
|
||||
|
||||
subject := env.Subject
|
||||
messageID := env.MessageID
|
||||
inReplyTo := env.InReplyTo
|
||||
date := env.Date
|
||||
|
||||
var fromEmail, fromName string
|
||||
if len(env.From) > 0 {
|
||||
fromEmail = fmt.Sprintf("%s@%s", env.From[0].Mailbox, env.From[0].Host)
|
||||
fromName = env.From[0].Name
|
||||
}
|
||||
|
||||
// Extract body from body section
|
||||
var bodyText string
|
||||
bodyBytes := buf.FindBodySection(&imap.FetchItemBodySection{})
|
||||
if bodyBytes != nil {
|
||||
bodyText = parseEmailBody(bodyBytes)
|
||||
}
|
||||
if bodyText == "" {
|
||||
bodyText = "(no body)"
|
||||
}
|
||||
|
||||
// Find author partner by email
|
||||
var authorID int64
|
||||
pool.QueryRow(ctx,
|
||||
`SELECT id FROM res_partner WHERE LOWER(email) = LOWER($1) LIMIT 1`, fromEmail,
|
||||
).Scan(&authorID)
|
||||
|
||||
// Thread matching via In-Reply-To
|
||||
var parentModel string
|
||||
var parentResID int64
|
||||
if len(inReplyTo) > 0 && inReplyTo[0] != "" {
|
||||
pool.QueryRow(ctx,
|
||||
`SELECT model, res_id FROM mail_message
|
||||
WHERE message_id = $1 AND model IS NOT NULL AND res_id IS NOT NULL
|
||||
LIMIT 1`, inReplyTo[0],
|
||||
).Scan(&parentModel, &parentResID)
|
||||
}
|
||||
|
||||
// Fallback: match by subject
|
||||
if parentModel == "" && subject != "" {
|
||||
clean := subject
|
||||
for _, prefix := range []string{"Re: ", "RE: ", "Fwd: ", "FW: ", "AW: "} {
|
||||
clean = strings.TrimPrefix(clean, prefix)
|
||||
}
|
||||
pool.QueryRow(ctx,
|
||||
`SELECT model, res_id FROM mail_message
|
||||
WHERE subject = $1 AND model IS NOT NULL AND res_id IS NOT NULL
|
||||
ORDER BY id DESC LIMIT 1`, clean,
|
||||
).Scan(&parentModel, &parentResID)
|
||||
}
|
||||
|
||||
_, err := pool.Exec(ctx,
|
||||
`INSERT INTO mail_message
|
||||
(subject, body, message_type, email_from, author_id, model, res_id,
|
||||
date, message_id, create_uid, write_uid, create_date, write_date)
|
||||
VALUES ($1, $2, 'email', $3, $4, $5, $6, $7, $8, 1, 1, NOW(), NOW())`,
|
||||
subject, bodyText,
|
||||
fmt.Sprintf("%s <%s>", fromName, fromEmail),
|
||||
nilIfZero(authorID),
|
||||
nilIfEmpty(parentModel),
|
||||
nilIfZero(parentResID),
|
||||
date,
|
||||
messageID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func parseEmailBody(raw []byte) string {
|
||||
entity, err := gomessage.Read(strings.NewReader(string(raw)))
|
||||
if err != nil {
|
||||
return string(raw) // fallback: raw text
|
||||
}
|
||||
|
||||
if mr := entity.MultipartReader(); mr != nil {
|
||||
var htmlBody, textBody string
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
ct, _, _ := part.Header.ContentType()
|
||||
body, _ := io.ReadAll(part.Body)
|
||||
switch {
|
||||
case strings.HasPrefix(ct, "text/html"):
|
||||
htmlBody = string(body)
|
||||
case strings.HasPrefix(ct, "text/plain"):
|
||||
textBody = string(body)
|
||||
}
|
||||
}
|
||||
if htmlBody != "" {
|
||||
return htmlBody
|
||||
}
|
||||
return textBody
|
||||
}
|
||||
|
||||
// Single part
|
||||
body, _ := io.ReadAll(entity.Body)
|
||||
return string(body)
|
||||
}
|
||||
|
||||
func nilIfZero(v int64) interface{} {
|
||||
if v == 0 {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func nilIfEmpty(v string) interface{} {
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// RegisterFetchmailCron ensures the message_id column exists for thread matching.
|
||||
func RegisterFetchmailCron(ctx context.Context, pool *pgxpool.Pool) {
|
||||
pool.Exec(ctx, `ALTER TABLE mail_message ADD COLUMN IF NOT EXISTS message_id VARCHAR(255)`)
|
||||
pool.Exec(ctx, `CREATE INDEX IF NOT EXISTS idx_mail_message_message_id ON mail_message(message_id)`)
|
||||
log.Println("fetchmail: ready (IMAP config via IMAP_HOST/IMAP_USER/IMAP_PASSWORD env vars)")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -64,8 +65,12 @@ func SendEmail(cfg *SMTPConfig, to, subject, body string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sanitize headers to prevent injection via \r\n
|
||||
sanitize := func(s string) string {
|
||||
return strings.NewReplacer("\r", "", "\n", "").Replace(s)
|
||||
}
|
||||
msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: %s\r\nMIME-Version: 1.0\r\nContent-Type: text/html; charset=utf-8\r\n\r\n%s",
|
||||
cfg.From, to, subject, body)
|
||||
sanitize(cfg.From), sanitize(to), sanitize(subject), body)
|
||||
|
||||
auth := smtp.PlainAuth("", cfg.User, cfg.Password, cfg.Host)
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
|
||||
@@ -83,12 +88,17 @@ func SendEmailWithAttachments(cfg *SMTPConfig, to []string, subject, bodyHTML st
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
|
||||
boundary := "==odoo-go-boundary-42=="
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
boundary := fmt.Sprintf("==odoo-go-%x==", b)
|
||||
|
||||
sanitize := func(s string) string {
|
||||
return strings.NewReplacer("\r", "", "\n", "").Replace(s)
|
||||
}
|
||||
var msg strings.Builder
|
||||
msg.WriteString(fmt.Sprintf("From: %s\r\n", cfg.From))
|
||||
msg.WriteString(fmt.Sprintf("To: %s\r\n", strings.Join(to, ", ")))
|
||||
msg.WriteString(fmt.Sprintf("Subject: %s\r\n", subject))
|
||||
msg.WriteString(fmt.Sprintf("From: %s\r\n", sanitize(cfg.From)))
|
||||
msg.WriteString(fmt.Sprintf("To: %s\r\n", sanitize(strings.Join(to, ", "))))
|
||||
msg.WriteString(fmt.Sprintf("Subject: %s\r\n", sanitize(subject)))
|
||||
msg.WriteString("MIME-Version: 1.0\r\n")
|
||||
|
||||
if len(attachments) > 0 {
|
||||
|
||||
Reference in New Issue
Block a user